diff --git a/scripts/convert_stable_diffusion_3_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_3_checkpoint_to_onnx.py new file mode 100644 index 000000000000..049b55cd8e33 --- /dev/null +++ b/scripts/convert_stable_diffusion_3_checkpoint_to_onnx.py @@ -0,0 +1,292 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import shutil +from pathlib import Path + +import onnx +import torch +from packaging import version +from torch.onnx import export + +from diffusers import OnnxRuntimeModel, OnnxStableDiffusion3Pipeline, StableDiffusion3Pipeline + + +is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") + + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +@torch.no_grad() +def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False): + dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = "cuda" + elif fp16 and not torch.cuda.is_available(): + raise ValueError("`float16` model export is only supported on GPUs with CUDA") + else: + device = "cpu" + pipeline = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device) + output_path = Path(output_path) + + # TEXT ENCODER + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size + text_input = pipeline.tokenizer( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=( + text_input.input_ids.to(device=device, dtype=torch.int32), + None, + None, + None, + True, + ), + output_path=output_path / "text_encoder" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output", "hidden_states"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + del pipeline.text_encoder + + num_tokens = pipeline.text_encoder_2.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder_2.config.hidden_size + text_input = pipeline.tokenizer_2( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer_2.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder_2, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=( + text_input.input_ids.to(device=device, dtype=torch.int32), + None, + None, + None, + True, + ), + output_path=output_path / "text_encoder_2" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output", "hidden_states"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + del pipeline.text_encoder_2 + + text_input = pipeline.tokenizer_3( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer_3.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder_3, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)), + output_path=output_path / "text_encoder_3" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + del pipeline.text_encoder_3 + + # TRANSFORMER + in_channels = pipeline.transformer.config.in_channels + sample_size = pipeline.transformer.config.sample_size + joint_attention_dim = pipeline.transformer.config.joint_attention_dim + pooled_projection_dim = pipeline.transformer.config.pooled_projection_dim + transformer_path = output_path / "transformer" / "model.onnx" + onnx_export( + pipeline.transformer, + model_args=( + torch.randn(2, in_channels, sample_size, sample_size).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, joint_attention_dim).to(device=device, dtype=dtype), + torch.randn(2, pooled_projection_dim).to(device=device, dtype=dtype), + torch.randn(2).to(device=device, dtype=dtype), + ), + output_path=transformer_path, + ordered_input_names=["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"], + output_names=["out_sample"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "hidden_states": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "encoder_hidden_states": {0: "batch", 1: "sequence", 2: "embed_dims"}, + "pooled_projections": {0: "batch", 1: "projection_dim"}, + "timestep": {0: "batch"}, + }, + opset=opset, + use_external_data_format=True, # UNet is > 2GB, so the weights need to be split + ) + model_path = str(transformer_path.absolute().as_posix()) + transformer_dir = os.path.dirname(model_path) + transformer = onnx.load(model_path) + # clean up existing tensor files + shutil.rmtree(transformer_dir) + os.mkdir(transformer_dir) + # collate external tensor files into one + onnx.save_model( + transformer, + model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) + del pipeline.transformer + + # VAE ENCODER + vae_encoder = pipeline.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size + # need to get the raw tensor output (sample) from the encoder + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + onnx_export( + vae_encoder, + model_args=( + torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / "vae_encoder" / "model.onnx", + ordered_input_names=["sample", "return_dict"], + output_names=["latent_sample"], + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + # VAE DECODER + vae_decoder = pipeline.vae + vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels + # forward only through the decoder part + vae_decoder.forward = vae_encoder.decode + onnx_export( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, sample_size, sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / "vae_decoder" / "model.onnx", + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + del pipeline.vae + + onnx_pipeline = OnnxStableDiffusion3Pipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), + vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), + tokenizer=pipeline.tokenizer, + text_encoder_2=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder_2"), + tokenizer_2=pipeline.tokenizer_2, + text_encoder_3=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder_3"), + tokenizer_3=pipeline.tokenizer_3, + transformer=OnnxRuntimeModel.from_pretrained(output_path / "transformer"), + scheduler=pipeline.scheduler, + ) + + onnx_pipeline.save_pretrained(output_path) + print("ONNX pipeline saved to", output_path) + + del pipeline + del onnx_pipeline + _ = OnnxStableDiffusion3Pipeline.from_pretrained(output_path, provider="CPUExecutionProvider") + print("ONNX pipeline is loadable") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--opset", + default=14, + type=int, + help="The version of the ONNX operator set to use.", + ) + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") + + args = parser.parse_args() + + convert_models(args.model_path, args.output_path, args.opset, args.fp16) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 789458a26299..c22bbb31c991 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -439,6 +439,7 @@ "OnnxStableDiffusionPipeline", "OnnxStableDiffusionUpscalePipeline", "StableDiffusionOnnxPipeline", + "OnnxStableDiffusion3Pipeline", ] ) @@ -878,6 +879,7 @@ OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, + OnnxStableDiffusion3Pipeline, ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7366520f4692..2c971043c201 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -352,6 +352,20 @@ ] ) +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) +else: + _import_structure["stable_diffusion_3"].extend( + [ + "OnnxStableDiffusion3Pipeline", + ] + ) + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() @@ -675,6 +689,16 @@ StableDiffusionOnnxPipeline, ) + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * + else: + from .stable_diffusion_3 import ( + OnnxStableDiffusion3Pipeline, + ) + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/stable_diffusion_3/__init__.py b/src/diffusers/pipelines/stable_diffusion_3/__init__.py index b0604589a208..467759445fa3 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_3/__init__.py @@ -8,6 +8,7 @@ is_flax_available, is_torch_available, is_transformers_available, + is_onnx_available, ) @@ -27,6 +28,18 @@ _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"] _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"] +try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["pipeline_onnx_stable_diffusion_3"] = [ + "OnnxStableDiffusion3Pipeline", + ] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -38,6 +51,14 @@ from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline + try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_onnx_objects import * + else: + from .pipeline_onnx_stable_diffusion_3 import OnnxStableDiffusion3Pipeline + else: import sys diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_onnx_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_onnx_stable_diffusion_3.py new file mode 100644 index 000000000000..d4fdd5be405d --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_onnx_stable_diffusion_3.py @@ -0,0 +1,701 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import ( + CLIPTokenizer, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ..onnx_utils import OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +class OnnxStableDiffusion3Pipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + text_encoder_2: OnnxRuntimeModel, + tokenizer_2: CLIPTokenizer, + text_encoder_3: OnnxRuntimeModel, + tokenizer_3: T5TokenizerFast, + transformer: OnnxRuntimeModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + dtype: Optional[torch.dtype] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return np.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + 4096, + ), + dtype=np.int32, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(input_ids=text_input_ids.astype(np.int32))[0] + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = np.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.reshape(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="np", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32)) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds[-2] + else: + prompt_embeds = prompt_embeds[-(clip_skip + 2)] + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = np.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.reshape(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = np.tile(pooled_prompt_embeds, (1, num_images_per_prompt, 1)) + pooled_prompt_embeds = pooled_prompt_embeds.reshape(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = np.concatenate([prompt_embed, prompt_2_embed], axis=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if t5_prompt_embed.shape[-1] > clip_prompt_embeds.shape[-1]: + clip_prompt_embeds = np.pad( + clip_prompt_embeds, + pad_width=((0, 0), (0, 0), (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])), + ) + else: + clip_prompt_embeds = clip_prompt_embeds[..., :t5_prompt_embed.shape[-1]] + + prompt_embeds = np.concatenate([clip_prompt_embeds, t5_prompt_embed], axis=-2) + pooled_prompt_embeds = np.concatenate([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = np.concatenate([negative_prompt_embed, negative_prompt_2_embed], axis=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if t5_negative_prompt_embed.shape[-1] > negative_clip_prompt_embeds.shape[-1]: + negative_clip_prompt_embeds = np.pad( + negative_clip_prompt_embeds, + pad_width=((0, 0), (0, 0), (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1])), + ) + else: + negative_clip_prompt_embeds = negative_clip_prompt_embeds[..., :t5_negative_prompt_embed.shape[-1]] + + negative_prompt_embeds = np.concatenate([negative_clip_prompt_embeds, t5_negative_prompt_embed], axis=-2) + negative_pooled_prompt_embeds = np.concatenate( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], axis=-1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 512, + width: Optional[int] =512, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds], axis=0) + pooled_prompt_embeds = np.concatenate([negative_pooled_prompt_embeds, pooled_prompt_embeds], axis=0) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + if generator is None: + generator = np.random + latents_shape = ( + batch_size * num_images_per_prompt, + 16, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents_dtype = prompt_embeds.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + + # 6. Denoising loop + with self.progress_bar(self.scheduler.timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.array_split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False + )[0] + latents = scheduler_output.numpy() + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + if i == len(timesteps) - 1 or (i + 1) % self.scheduler.order == 0: + progress_bar.update() + + latents = (latents / 1.5305) + 0.0609 + + image = self.vae_decoder(latent_sample=latents)[0] + image = self.image_processor.postprocess(torch.from_numpy(image), output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/tests/pipelines/stable_diffusion_3/test_onnx_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_onnx_pipeline_stable_diffusion_3.py new file mode 100644 index 000000000000..02379aaf3860 --- /dev/null +++ b/tests/pipelines/stable_diffusion_3/test_onnx_pipeline_stable_diffusion_3.py @@ -0,0 +1,296 @@ +import unittest + +import numpy as np +import shutil +import torch +from pathlib import Path +from torch.onnx import export + +from diffusers import StableDiffusion3Pipeline, OnnxStableDiffusion3Pipeline, OnnxRuntimeModel +from diffusers.utils.testing_utils import ( + torch_device, + is_onnx_available +) + +from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin + +if is_onnx_available(): + import onnxruntime as ort + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + export( + model, + model_args, + f=output_path, + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + ) + +def build_tiny_onnx_model(): + torch.manual_seed(0) + pipeline = StableDiffusion3Pipeline.from_pretrained("yujiepan/stable-diffusion-3-tiny-random").to(torch_device) + dtype = torch.float32 + + # TEXT ENCODER + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size + text_input = pipeline.tokenizer( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder, + model_args=( + text_input.input_ids.to(device=torch_device, dtype=torch.int32), + None, + None, + None, + True, + ), + output_path=Path("sd3/text_encoder/model.onnx"), + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output", "hidden_states"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + ) + del pipeline.text_encoder + + num_tokens = pipeline.text_encoder_2.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder_2.config.hidden_size + text_input = pipeline.tokenizer_2( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer_2.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder_2, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=( + text_input.input_ids.to(device=torch_device, dtype=torch.int32), + None, + None, + None, + True, + ), + output_path=Path("sd3/text_encoder_2/model.onnx"), + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output", "hidden_states"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + ) + del pipeline.text_encoder_2 + + text_input = pipeline.tokenizer_3( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer_3.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder_3, + model_args=(text_input.input_ids.to(device=torch_device, dtype=torch.int32)), + output_path=Path("sd3/text_encoder_3/model.onnx"), + ordered_input_names=["input_ids"], + output_names=["last_hidden_state"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + ) + del pipeline.text_encoder_3 + + # TRANSFORMER + in_channels = pipeline.transformer.config.in_channels + sample_size = pipeline.transformer.config.sample_size + joint_attention_dim = pipeline.transformer.config.joint_attention_dim + pooled_projection_dim = pipeline.transformer.config.pooled_projection_dim + transformer_path = Path("sd3/transformer/model.onnx") + onnx_export( + pipeline.transformer, + model_args=( + torch.randn(2, in_channels, sample_size, sample_size).to(device=torch_device, dtype=dtype), + torch.randn(2, num_tokens, joint_attention_dim).to(device=torch_device, dtype=dtype), + torch.randn(2, pooled_projection_dim).to(device=torch_device, dtype=dtype), + torch.randn(2).to(device=torch_device, dtype=dtype), + ), + output_path=transformer_path, + ordered_input_names=["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"], + output_names=["out_sample"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "hidden_states": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "encoder_hidden_states": {0: "batch", 1: "sequence", 2: "embed_dims"}, + "pooled_projections": {0: "batch", 1: "projection_dim"}, + "timestep": {0: "batch"}, + }, + ) + + # VAE ENCODER + vae_encoder = pipeline.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + onnx_export( + vae_encoder, + model_args=( + torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=torch_device, dtype=dtype), + False, + ), + output_path=Path("sd3/vae_encoder/model.onnx"), + ordered_input_names=["sample", "return_dict"], + output_names=["latent_sample"], + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + ) + + # VAE DECODER + vae_decoder = pipeline.vae + vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels + vae_decoder.forward = vae_encoder.decode + onnx_export( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, sample_size, sample_size).to(device=torch_device, dtype=dtype), + False, + ), + output_path=Path("sd3/vae_decoder/model.onnx"), + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + ) + del pipeline.vae + + onnx_pipeline = OnnxStableDiffusion3Pipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained("sd3/vae_encoder"), + vae_decoder=OnnxRuntimeModel.from_pretrained("sd3/vae_decoder"), + text_encoder=OnnxRuntimeModel.from_pretrained("sd3/text_encoder"), + tokenizer=pipeline.tokenizer, + text_encoder_2=OnnxRuntimeModel.from_pretrained("sd3/text_encoder_2"), + tokenizer_2=pipeline.tokenizer_2, + text_encoder_3=OnnxRuntimeModel.from_pretrained("sd3/text_encoder_3"), + tokenizer_3=pipeline.tokenizer_3, + transformer=OnnxRuntimeModel.from_pretrained("sd3/transformer"), + scheduler=pipeline.scheduler, + ) + + onnx_pipeline.save_pretrained("sd3") + +class OnnxStableDiffusion3PipelineFastTests(unittest.TestCase, OnnxPipelineTesterMixin): + pipeline_class = OnnxStableDiffusion3Pipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + checkpoint = "sd3" + + def get_dummy_inputs(self, device, seed=0): + generator = np.random.RandomState(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + } + return inputs + + @classmethod + def setUpClass(self): + build_tiny_onnx_model() + + @classmethod + def tearDownClass(self): + shutil.rmtree("sd3", ignore_errors=True) + + def test_onnx_stable_diffusion_3_different_prompts(self): + pipe = self.pipeline_class.from_pretrained(self.checkpoint) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + inputs["prompt_3"] = "another different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + assert max_diff > 1e-2 + + def test_onnx_stable_diffusion_3_different_negative_prompts(self): + pipe = self.pipeline_class.from_pretrained(self.checkpoint) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt_2"] = "deformed" + inputs["negative_prompt_3"] = "blurry" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + assert max_diff > 1e-2 + + def test_onnx_stable_diffusion_3_prompt_embeds(self): + pipe = self.pipeline_class.from_pretrained(self.checkpoint) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + do_classifier_free_guidance = inputs["guidance_scale"] > 1 + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipe.encode_prompt( + prompt, + prompt_2=None, + prompt_3=None, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4