Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Latent Consistency models ONNX export #1469

Merged
merged 23 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,8 @@ The following ORT classes are available for the following custom tasks.

[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline
- __call__

#### ORTLatentConsistencyModelPipeline

[[autodoc]] onnxruntime.ORTLatentConsistencyModelPipeline
- __call__
16 changes: 16 additions & 0 deletions docs/source/onnxruntime/usage_guides/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,19 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
image.save("sailing_ship.png")
```



## Latent Consistency Models

### Text-to-Image

Here is an example of how you can load a Latent Consistency Models (LCMs) from [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) and run inference using ONNX Runtime :

```python
from optimum.onnxruntime import ORTLatentConsistencyModelPipeline

model_id = "SimianLuo/LCM_Dreamshaper_v7"
pipeline = ORTLatentConsistencyModelPipeline.from_pretrained(model_id, export=True)
prompt = "sailing ship in storm by Leonardo da Vinci"
images = pipeline(prompt, num_inference_steps=4, guidance_scale=8.0).images
```
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def ordered_inputs(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -
sig = inspect.signature(model.call)

for param in sig.parameters:
param_regex = re.compile(rf"{param}(\.\d*)?")
param_regex = re.compile(rf"{param}(\..*)?$")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this modification comes from timestep matching both timestep and timestep_cond previously (behavior that we don't want), we still want past_key_value to match past_key_values.0.key though

to_insert = []
for name, dynamic_axes in inputs.items():
if re.match(param_regex, name):
Expand Down
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs["text_embeds"] = {0: "batch_size"}
common_inputs["time_ids"] = {0: "batch_size"}

if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
common_inputs["timestep_cond"] = {0: "batch_size"}
return common_inputs

@property
Expand Down
7 changes: 2 additions & 5 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,11 +1408,8 @@ def _infer_task_from_model_name_or_path(
)
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
if getattr(model_info, "library_name", None) == "diffusers":
# TODO : getattr(model_info, "model_index") defining auto_model_class_name currently set to None
for task in ("stable-diffusion-xl", "stable-diffusion"):
if task in model_info.tags:
inferred_task_name = task
break
class_name = model_info.config["diffusers"]["class_name"]
inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
elif getattr(model_info, "library_name", None) == "timm":
inferred_task_name = "image-classification"
else:
Expand Down
4 changes: 4 additions & 0 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
]
else:
_import_structure["modeling_diffusion"] = [
Expand All @@ -86,6 +87,7 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
]


Expand Down Expand Up @@ -135,6 +137,7 @@
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_diffusers_objects import (
ORTLatentConsistencyModelPipeline,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
Expand All @@ -143,6 +146,7 @@
)
else:
from .modeling_diffusion import (
ORTLatentConsistencyModelPipeline,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
Expand Down
14 changes: 13 additions & 1 deletion optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
Expand Down Expand Up @@ -501,6 +502,7 @@ def forward(
encoder_hidden_states: np.ndarray,
text_embeds: Optional[np.ndarray] = None,
time_ids: Optional[np.ndarray] = None,
timestep_cond: Optional[np.ndarray] = None,
):
onnx_inputs = {
"sample": sample,
Expand All @@ -512,7 +514,8 @@ def forward(
onnx_inputs["text_embeds"] = text_embeds
if time_ids is not None:
onnx_inputs["time_ids"] = time_ids

if timestep_cond is not None:
onnx_inputs["timestep_cond"] = timestep_cond
outputs = self.session.run(None, onnx_inputs)
return outputs

Expand Down Expand Up @@ -562,6 +565,15 @@ class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDi
__call__ = StableDiffusionInpaintPipelineMixin.__call__


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTLatentConsistencyModelPipeline(ORTStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
fxmarty marked this conversation as resolved.
Show resolved Hide resolved
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
"""

__call__ = LatentConsistencyPipelineMixin.__call__


class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase):
auto_model_class = StableDiffusionXLImg2ImgPipeline

Expand Down
230 changes: 230 additions & 0 deletions optimum/pipelines/diffusers/pipeline_latent_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from .pipeline_stable_diffusion import StableDiffusionPipelineMixin


logger = logging.getLogger(__name__)


class LatentConsistencyPipelineMixin(StableDiffusionPipelineMixin):
# Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
original_inference_steps: int = None,
guidance_scale: float = 8.5,
num_images_per_prompt: int = 1,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`Optional[int]`, defaults to None):
The height in pixels of the generated image.
width (`Optional[int]`, defaults to None):
The width in pixels of the generated image.
num_inference_steps (`int`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
generator (`Optional[np.random.RandomState]`, defaults to `None`)::
A np.random.RandomState to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (Optional[Callable], defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
guidance_rescale (`float`, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor

# Don't need to get negative prompts due to LCM guided distillation
negative_prompt = None
negative_prompt_embeds = None

# check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)

# define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

if generator is None:
generator = np.random

prompt_embeds = self._encode_prompt(
prompt,
num_images_per_prompt,
False,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps, original_inference_steps=original_inference_steps)
timesteps = self.scheduler.timesteps

latents = self.prepare_latents(
batch_size * num_images_per_prompt,
self.unet.config["in_channels"],
height,
width,
prompt_embeds.dtype,
generator,
latents,
)

bs = batch_size * num_images_per_prompt
# get Guidance Scale Embedding
w = np.full(bs, guidance_scale - 1, dtype=prompt_embeds.dtype)
w_embedding = self.get_guidance_scale_embedding(
w, embedding_dim=self.unet.config["time_cond_proj_dim"], dtype=prompt_embeds.dtype
)

# Adapted from diffusers to extend it for other runtimes than ORT
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)

num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latents,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
timestep_cond=w_embedding,
)[0]

# compute the previous noisy sample x_t -> x_t-1
latents, denoised = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False
)
latents, denoised = latents.numpy(), denoised.numpy()

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type == "latent":
image = denoised
has_nsfw_concept = None
else:
denoised /= self.vae_decoder.config["scaling_factor"]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=denoised[i : i + 1])[0] for i in range(denoised.shape[0])]
)
image, has_nsfw_concept = self.run_safety_checker(image)

if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

# Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=None):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings

Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
w = w * 1000
half_dim = embedding_dim // 2
emb = np.log(10000.0) / (half_dim - 1)
emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb)
emb = w[:, None] * emb[None, :]
emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)

if embedding_dim % 2 == 1: # zero pad
emb = np.pad(emb, [(0, 0), (0, 1)])

assert emb.shape == (w.shape[0], embedding_dim)
return emb
11 changes: 11 additions & 0 deletions optimum/utils/dummy_diffusers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["diffusers"])


class ORTLatentConsistencyModelPipeline(metaclass=DummyObject):
echarlaix marked this conversation as resolved.
Show resolved Hide resolved
_backends = ["diffusers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["diffusers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["diffusers"])
Loading
Loading