From c23e99ea76e9a2643e4947c0d0ab8577511bf121 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 15 Jul 2024 11:31:56 +0200 Subject: [PATCH] feat(runner): add support for SD3-medium model (#118) This commit introduces support for the Stable Diffusion 3 Medium model from Hugging Face: [https://huggingface.co/stabilityai/stable-diffusion-3-medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium). Please be aware that this model has restrictive licensing at the time of writing and is not yet advised for public use. Ensure you read and understand the [licensing terms](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE) before enabling this model on your orchestrator. --- runner/app/pipelines/text_to_image.py | 26 +++++++++++++++++++++----- runner/app/routes/text_to_image.py | 3 ++- runner/dl_checkpoints.sh | 1 + runner/requirements.txt | 4 +++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 278c04e5..0f9f4795 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -1,6 +1,7 @@ import logging import os from typing import List, Tuple, Optional +from enum import Enum import PIL import torch @@ -9,6 +10,7 @@ EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, + StableDiffusion3Pipeline, ) from huggingface_hub import file_download, hf_hub_download from safetensors.torch import load_file @@ -24,7 +26,17 @@ logger = logging.getLogger(__name__) -SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning" + +class ModelName(Enum): + """Enumeration mapping model names to their corresponding IDs.""" + + SDXL_LIGHTNING = "ByteDance/SDXL-Lightning" + SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers" + + @classmethod + def list(cls): + """Return a list of all model IDs.""" + return list(map(lambda c: c.value, cls)) class TextToImagePipeline(Pipeline): @@ -46,7 +58,7 @@ def __init__(self, model_id: str): for _, _, files in os.walk(folder_path) for fname in files ) - or SDXL_LIGHTNING_MODEL_ID in model_id + or ModelName.SDXL_LIGHTNING.value in model_id ) if torch_device != "cpu" and has_fp16_variant: logger.info("TextToImagePipeline loading fp16 variant for %s", model_id) @@ -59,7 +71,7 @@ def __init__(self, model_id: str): kwargs["torch_dtype"] = torch.bfloat16 # Special case SDXL-Lightning because the unet for SDXL needs to be swapped - if SDXL_LIGHTNING_MODEL_ID in model_id: + if ModelName.SDXL_LIGHTNING.value in model_id: base = "stabilityai/stable-diffusion-xl-base-1.0" # ByteDance/SDXL-Lightning-2step @@ -81,7 +93,7 @@ def __init__(self, model_id: str): unet.load_state_dict( load_file( hf_hub_download( - SDXL_LIGHTNING_MODEL_ID, + ModelName.SDXL_LIGHTNING.value, f"{unet_id}.safetensors", cache_dir=kwargs["cache_dir"], ), @@ -96,6 +108,10 @@ def __init__(self, model_id: str): self.ldm.scheduler = EulerDiscreteScheduler.from_config( self.ldm.scheduler.config, timestep_spacing="trailing" ) + elif ModelName.SD3_MEDIUM.value in model_id: + self.ldm = StableDiffusion3Pipeline.from_pretrained(model_id, **kwargs).to( + torch_device + ) else: self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to( torch_device @@ -190,7 +206,7 @@ def __call__( # SD turbo models were trained without guidance_scale so # it should be set to 0 kwargs["guidance_scale"] = 0.0 - elif SDXL_LIGHTNING_MODEL_ID in self.model_id: + elif ModelName.SDXL_LIGHTNING.value in self.model_id: # SDXL-Lightning models should have guidance_scale = 0 and use # the correct number of inference steps for the unet checkpoint loaded kwargs["guidance_scale"] = 0.0 diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 3f52e36d..942baaff 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -68,7 +68,8 @@ async def text_to_image( for seed in seeds: try: params.seed = seed - imgs, nsfw_check = pipeline(**params.model_dump()) + kwargs = {k: v for k,v in params.model_dump().items() if k != "model_id"} + imgs, nsfw_check = pipeline(**kwargs) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) except Exception as e: diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 375d69c4..13902220 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -52,6 +52,7 @@ function download_all_models() { huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models + huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"} # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models diff --git a/runner/requirements.txt b/runner/requirements.txt index 7bd794c3..17b38644 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -1,4 +1,4 @@ -diffusers==0.28.0 +diffusers==0.29.2 accelerate==0.30.1 transformers==4.41.1 fastapi==0.111.0 @@ -14,3 +14,5 @@ deepcache==0.1.1 safetensors==0.4.3 scipy==1.13.0 numpy==1.26.4 +sentencepiece== 0.2.0 +protobuf==5.27.2