Skip to content

Commit

Permalink
conditional sd3 and flux modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 19, 2024
1 parent 813002d commit 2183fb7
Showing 1 changed file with 62 additions and 37 deletions.
99 changes: 62 additions & 37 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,8 @@
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
FluxPipeline,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3Pipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
Expand Down Expand Up @@ -955,48 +951,78 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi
auto_model_class = LatentConsistencyModelImg2ImgPipeline


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Pipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusion3Pipeline).
"""
class ORTUnavailablePipeline:
MIN_VERSION = None

main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = StableDiffusion3Pipeline
def __init__(self, *args, **kwargs):
raise NotImplementedError(
f"The pipeline {self.__class__.__name__} is not available in the current version of `diffusers`. "
f"Please upgrade `diffusers` to {self.MIN_VERSION} or later."
)


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Img2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusion3Img2ImgPipeline).
"""
if check_if_diffusers_greater("0.29.0"):
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline

main_input_name = "image"
export_feature = "image-to-image"
auto_model_class = StableDiffusion3Img2ImgPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Pipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusion3Pipeline).
"""

main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = StableDiffusion3Pipeline

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3InpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusion3InpaintPipeline).
"""
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Img2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusion3Img2ImgPipeline).
"""

main_input_name = "prompt"
export_feature = "inpainting"
auto_model_class = StableDiffusion3InpaintPipeline
main_input_name = "image"
export_feature = "image-to-image"
auto_model_class = StableDiffusion3Img2ImgPipeline

else:

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.FluxPipeline](https://huggingface.co/docs/diffusers/api/pipelines/flux/text2img#diffusers.FluxPipeline).
"""
class ORTStableDiffusion3Pipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.29.0"

main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = FluxPipeline
class ORTStableDiffusion3Img2ImgPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.29.0"


if check_if_diffusers_greater("0.30.0"):
from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3InpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusion3InpaintPipeline).
"""

main_input_name = "prompt"
export_feature = "inpainting"
auto_model_class = StableDiffusion3InpaintPipeline

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.FluxPipeline](https://huggingface.co/docs/diffusers/api/pipelines/flux/text2img#diffusers.FluxPipeline).
"""

main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = FluxPipeline

else:

class ORTStableDiffusion3InpaintPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.30.0"

class ORTFluxPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.30.0"


SUPPORTED_ORT_PIPELINES = [
Expand Down Expand Up @@ -1049,7 +1075,6 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tr
ORT_INPAINT_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", ORTStableDiffusionInpaintPipeline),
("stable-diffusion-3", ORTStableDiffusion3InpaintPipeline),
("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline),
]
)
Expand Down

0 comments on commit 2183fb7

Please sign in to comment.