Skip to content

Commit

Permalink
Update/Fix Pipeline Mixins and ORT Pipelines (#2021)
Browse files Browse the repository at this point in the history
* created auto task mappings

* added correct auto classes

* created auto task mappings

* added correct auto classes

* added ort/auto diffusion classes

* fix ORTPipeline detection

* start test refactoring

* dynamic dtype

* support torch random numbers generator

* compact diffusion testing suite

* fix

* test

* test

* test

* use latent-consistency architecture name instead of lcm

* fix

* add ort diffusion pipeline tests

* added dummy objects

* remove duplicate code

* update stable diffusion mixin

* update latent consistency

* update sd for img2img

* update latent consistency

* update model parts to use frozen dict

* update tests and utils

* updated all mixins, enabled all tests; all are passing except some reproducibility and comparison tests (7 failed, 35 passed)

* fix sd xl hidden states

* style

* support testing without diffusers

* remove unnecessary

* revert

* export vae encoder by returning its latent distribution parameters

* fix the modeling to handle distributions

* create vae class to minimize changes in pipeline mixins

* remove unnecessary tests

* style

* style

* update diffusion models export test

* style

* fall back for when block_out_channels is not in vae config

* remove model parts from optimum.onnxruntime

* added .to to model parts

* remove custom mixins

* style

* Update optimum/exporters/onnx/model_configs.py

Co-authored-by: Ella Charlaix <[email protected]>

* Update optimum/exporters/onnx/model_configs.py

* conversion to numpy always works

* test adding two new pipelines

* remove duplicated tests

* match diffusers numpy input

* simplify model saving

* extend tests and only translate generators

* cleanup

* reduce parent model usage in model parts

* fix

* new tiny onnx diffusion model with configs

* model_save_path

* Update optimum/onnxruntime/modeling_diffusion.py

Co-authored-by: Ella Charlaix <[email protected]>

* migrate tiny-stable-diffusion-onnx

* resolve breaking change and mandatory arguments

* overwrite _get_add_time_ids

* fix

* remove inference calls from loading tests

* misc

* better compatibility between model parts and parent pipeline

* remove subfolder

* misc

* update

* support passing safety checker

* dummies

* remove the need for ORTPipeline

---------

Co-authored-by: Ella Charlaix <[email protected]>
  • Loading branch information
IlyasMoutawwakil and echarlaix authored Oct 9, 2024
1 parent d9754ab commit d3c56cd
Show file tree
Hide file tree
Showing 21 changed files with 914 additions and 3,410 deletions.
6 changes: 3 additions & 3 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:


class VaeEncoderOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-2
ATOL_FOR_VALIDATION = 1e-4
# The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
# operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
Expand All @@ -1132,12 +1132,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
"latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
}


class VaeDecoderOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
ATOL_FOR_VALIDATION = 1e-4
# The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
# operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
Expand Down
24 changes: 3 additions & 21 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,6 @@

from diffusers import (
DiffusionPipeline,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
Expand Down Expand Up @@ -92,27 +87,13 @@ def _get_submodels_for_export_diffusion(
Returns the components of a Stable Diffusion model.
"""

is_stable_diffusion = isinstance(
pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline)
)
is_stable_diffusion_xl = isinstance(
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
)
is_latent_consistency_model = isinstance(
pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline)
)

if is_stable_diffusion_xl:
projection_dim = pipeline.text_encoder_2.config.projection_dim
elif is_stable_diffusion:
projection_dim = pipeline.text_encoder.config.projection_dim
elif is_latent_consistency_model:
projection_dim = pipeline.text_encoder.config.projection_dim
else:
raise ValueError(
f"The export of a DiffusionPipeline model with the class name {pipeline.__class__.__name__} is currently not supported in Optimum. "
"Please open an issue or submit a PR to add the support."
)
projection_dim = pipeline.text_encoder.config.projection_dim

models_for_export = {}

Expand All @@ -139,7 +120,8 @@ def _get_submodels_for_export_diffusion(
vae_encoder = copy.deepcopy(pipeline.vae)
if not is_torch_greater_or_equal_than_2_1:
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)
vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()}
# we return the distribution parameters to be able to recreate it in the decoder
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
Expand Down
16 changes: 16 additions & 0 deletions optimum/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tu
return src_paths, dst_paths


def _get_model_external_data_paths(model_path: Path) -> List[Path]:
    """
    Gets the paths of the external data files referenced by an ONNX model.

    The model is loaded with `load_external_data=False`, so the external files themselves are not read;
    only the locations recorded on the tensors are inspected.

    Args:
        model_path (`Path`):
            Path to the ONNX model file.

    Returns:
        `List[Path]`: One path per distinct external data location, each joined to the parent directory
        of `model_path`, in order of first appearance. The list is empty when no tensor is marked as
        external.
    """

    onnx_model = onnx.load(str(model_path), load_external_data=False)
    model_tensors = _get_initializer_tensors(onnx_model)
    # keep only the locations of tensors that are flagged as stored externally
    external_locations = [
        ExternalDataInfo(tensor).location
        for tensor in model_tensors
        if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
    ]
    # several tensors may point to the same external file: report each file once (order preserved)
    # so that callers do not process the same path repeatedly
    return [model_path.parent / location for location in dict.fromkeys(external_locations)]


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
"""
Checks if the model uses external data.
Expand Down
8 changes: 8 additions & 0 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTStableDiffusionXLInpaintPipeline",
"ORTLatentConsistencyModelPipeline",
"ORTLatentConsistencyModelImg2ImgPipeline",
"ORTPipelineForImage2Image",
"ORTPipelineForInpainting",
"ORTPipelineForText2Image",
Expand All @@ -92,6 +94,8 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTStableDiffusionXLInpaintPipeline",
"ORTLatentConsistencyModelImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
"ORTPipelineForImage2Image",
"ORTPipelineForInpainting",
Expand Down Expand Up @@ -148,6 +152,7 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_diffusers_objects import (
ORTDiffusionPipeline,
ORTLatentConsistencyModelImg2ImgPipeline,
ORTLatentConsistencyModelPipeline,
ORTPipelineForImage2Image,
ORTPipelineForInpainting,
Expand All @@ -156,11 +161,13 @@
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLPipeline,
)
else:
from .modeling_diffusion import (
ORTDiffusionPipeline,
ORTLatentConsistencyModelImg2ImgPipeline,
ORTLatentConsistencyModelPipeline,
ORTPipelineForImage2Image,
ORTPipelineForInpainting,
Expand All @@ -169,6 +176,7 @@
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLPipeline,
)
else:
Expand Down
19 changes: 19 additions & 0 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ def dtype(self):

return None

def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
    """
    Validates a device/dtype request against the current state of this model part.

    Nothing is moved or cast here: the call succeeds only when the requested device and dtype
    already match `self.device` and `self.dtype`.

    Args:
        *args:
            Positional `torch.device` / `torch.dtype` values; they take precedence over the
            keyword arguments. Positional values of any other type are ignored.
        device (`torch.device`, `str` or `int`, *optional*):
            Requested device. `str` and `int` specifications are converted with `torch.device`
            before being compared.
        dtype (`torch.dtype`, *optional*):
            Requested dtype.

    Raises:
        ValueError: If the requested device differs from `self.device`.
        NotImplementedError: If the requested dtype differs from `self.dtype`.
    """
    for arg in args:
        if isinstance(arg, torch.device):
            device = arg
        elif isinstance(arg, torch.dtype):
            dtype = arg

    if device is not None and device != self.device:
        # A `torch.device` never compares equal to a `str` or an `int`, so a spec such as "cpu"
        # reaches this point even when it names the current device. Compare again after
        # normalizing; a spec that torch cannot parse is treated as a mismatch.
        try:
            same_device = torch.device(device) == self.device
        except (RuntimeError, TypeError):
            same_device = False

        if not same_device:
            raise ValueError(
                "Cannot change the device of a model part without changing the device of the parent model. "
                "Please use the `to` method of the parent model to change the device."
            )

    if dtype is not None and dtype != self.dtype:
        raise NotImplementedError(
            f"Cannot change the dtype of the model from {self.dtype} to {dtype}. "
            "Please export the model with the desired dtype."
        )

@abstractmethod
def forward(self, *args, **kwargs):
    """
    Abstract entry point of the model part; no behavior is defined here.

    Marked with `@abstractmethod`, so concrete subclasses are expected to provide the implementation.
    """
    pass
Expand Down
Loading

0 comments on commit d3c56cd

Please sign in to comment.