Skip to content

Commit

Permalink
Enable sdpa export for SD unet component (#1637)
Browse files Browse the repository at this point in the history
Enable sdpa export
  • Loading branch information
echarlaix committed Jan 19, 2024
1 parent 8b438ab commit a1e6583
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ def _get_submodels_for_export_stable_diffusion(
models_for_export["text_encoder"] = pipeline.text_encoder

# U-NET
# PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention
pipeline.unet.set_attn_processor(AttnProcessor())
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
if not is_torch_greater_or_equal_than_2_1:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.unet.config.text_encoder_projection_dim = projection_dim
# The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
Expand All @@ -141,14 +143,14 @@ def _get_submodels_for_export_stable_diffusion(

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = copy.deepcopy(pipeline.vae)
if not version.parse(torch.__version__) >= version.parse("2.1.0"):
if not is_torch_greater_or_equal_than_2_1:
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)
vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()}
models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
vae_decoder = copy.deepcopy(pipeline.vae)
if not version.parse(torch.__version__) >= version.parse("2.1.0"):
if not is_torch_greater_or_equal_than_2_1:
vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder)
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder
Expand Down

0 comments on commit a1e6583

Please sign in to comment.