From 0d030409764a3f314fd0885be2a8fab5c47a942a Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Mon, 15 Jul 2024 11:31:56 +0200
Subject: [PATCH] feat(runner): add support for SD3-medium model (#118)

This commit introduces support for the Stable Diffusion 3 Medium model
from Hugging Face:
https://huggingface.co/stabilityai/stable-diffusion-3-medium

Please be aware that, at the time of writing, this model is distributed
under a restrictive license and is not yet advised for public use. Read
and understand the licensing terms
(https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)
before enabling this model on your orchestrator.
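
As a rough sketch of the intended usage (assuming a runner serving this
model on localhost:8000; see runner/app/routes/text_to_image.py for the
authoritative request schema), a request might look like:

    import requests

    # Hypothetical endpoint and port; adjust to your deployment.
    resp = requests.post(
        "http://localhost:8000/text-to-image",
        json={
            "model_id": "stabilityai/stable-diffusion-3-medium-diffusers",
            "prompt": "a photo of a cat wearing a spacesuit",
        },
    )
    resp.raise_for_status()
    print(resp.json())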
---
 runner/app/pipelines/text_to_image.py  |  26 +++-
 runner/app/pipelines/utils/__init__.py |  12 ++
 runner/app/pipelines/utils/audio.py    |  77 ++++++++++++
 runner/app/pipelines/utils/utils.py    | 158 +++++++++++++++++++++++++
 runner/app/routes/text_to_image.py     |   3 +-
 runner/dl_checkpoints.sh               |   1 +
 runner/requirements.txt                |   4 +-
 7 files changed, 274 insertions(+), 7 deletions(-)
 create mode 100644 runner/app/pipelines/utils/__init__.py
 create mode 100644 runner/app/pipelines/utils/audio.py
 create mode 100644 runner/app/pipelines/utils/utils.py

diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py
index 278c04e5..0f9f4795 100644
--- a/runner/app/pipelines/text_to_image.py
+++ b/runner/app/pipelines/text_to_image.py
@@ -1,6 +1,7 @@
 import logging
 import os
 from typing import List, Tuple, Optional
+from enum import Enum
 
 import PIL
 import torch
@@ -9,6 +10,7 @@
     EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
+    StableDiffusion3Pipeline,
 )
 from huggingface_hub import file_download, hf_hub_download
 from safetensors.torch import load_file
@@ -24,7 +26,17 @@
 
 logger = logging.getLogger(__name__)
 
-SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"
+
+class ModelName(Enum):
+    """Enumeration mapping model names to their corresponding IDs."""
+
+    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
+    SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    @classmethod
+    def list(cls):
+        """Return a list of all model IDs."""
+        return list(map(lambda c: c.value, cls))
 
 
 class TextToImagePipeline(Pipeline):
@@ -46,7 +58,7 @@ def __init__(self, model_id: str):
                 for _, _, files in os.walk(folder_path)
                 for fname in files
             )
-            or SDXL_LIGHTNING_MODEL_ID in model_id
+            or ModelName.SDXL_LIGHTNING.value in model_id
         )
         if torch_device != "cpu" and has_fp16_variant:
             logger.info("TextToImagePipeline loading fp16 variant for %s", model_id)
@@ -59,7 +71,7 @@ def __init__(self, model_id: str):
             kwargs["torch_dtype"] = torch.bfloat16
 
         # Special case SDXL-Lightning because the unet for SDXL needs to be swapped
-        if SDXL_LIGHTNING_MODEL_ID in model_id:
+        if ModelName.SDXL_LIGHTNING.value in model_id:
             base = "stabilityai/stable-diffusion-xl-base-1.0"
 
             # ByteDance/SDXL-Lightning-2step
@@ -81,7 +93,7 @@ def __init__(self, model_id: str):
                 unet.load_state_dict(
                     load_file(
                         hf_hub_download(
-                            SDXL_LIGHTNING_MODEL_ID,
+                            ModelName.SDXL_LIGHTNING.value,
                             f"{unet_id}.safetensors",
                             cache_dir=kwargs["cache_dir"],
                         ),
@@ -96,6 +108,10 @@ def __init__(self, model_id: str):
             self.ldm.scheduler = EulerDiscreteScheduler.from_config(
                 self.ldm.scheduler.config, timestep_spacing="trailing"
             )
+        elif ModelName.SD3_MEDIUM.value in model_id:
+            self.ldm = StableDiffusion3Pipeline.from_pretrained(model_id, **kwargs).to(
+                torch_device
+            )
         else:
             self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
                 torch_device
@@ -190,7 +206,7 @@ def __call__(
             # SD turbo models were trained without guidance_scale so
             # it should be set to 0
             kwargs["guidance_scale"] = 0.0
-        elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
+        elif ModelName.SDXL_LIGHTNING.value in self.model_id:
             # SDXL-Lightning models should have guidance_scale = 0 and use
             # the correct number of inference steps for the unet checkpoint loaded
             kwargs["guidance_scale"] = 0.0
diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py
new file mode 100644
index 00000000..a5e6f1eb
--- /dev/null
+++ b/runner/app/pipelines/utils/__init__.py
@@ -0,0 +1,12 @@
+"""This module contains several utility functions that are used across the pipelines module."""
+
+from app.pipelines.utils.utils import (
+    get_model_dir,
+    get_model_path,
+    get_torch_device,
+    validate_torch_device,
+    is_lightning_model,
+    is_turbo_model,
+    get_temp_file,
+    SafetyChecker,
+)
diff --git a/runner/app/pipelines/utils/audio.py b/runner/app/pipelines/utils/audio.py
new file mode 100644
index 00000000..5386fcbd
--- /dev/null
+++ b/runner/app/pipelines/utils/audio.py
@@ -0,0 +1,77 @@
+"""This module provides functionality for converting audio files between different formats."""
+
+from io import BytesIO
+
+import av
+from fastapi import UploadFile
+
+
+class AudioConversionError(Exception):
+    """Raised when an audio file cannot be converted."""
+
+    def __init__(self, message="Audio conversion failed."):
+        self.message = message
+        super().__init__(self.message)
+
+
+class AudioConverter:
+    """Converts audio files to different formats."""
+
+    @staticmethod
+    def convert(
+        upload_file: UploadFile, output_extension: str, output_codec=None
+    ) -> bytes:
+        """Converts an audio file to a different format.
+
+        Args:
+            upload_file: The audio file to convert.
+            output_extension: The desired output format.
+            output_codec: The desired output codec.
+
+        Returns:
+            The converted audio file as bytes.
+        """
+        if output_extension.startswith("."):
+            output_extension = output_extension.lstrip(".")
+
+        output_buffer = BytesIO()
+
+        input_container = av.open(upload_file.file)
+        output_container = av.open(output_buffer, mode="w", format=output_extension)
+
+        try:
+            for stream in input_container.streams.audio:
+                audio_stream = output_container.add_stream(
+                    output_codec if output_codec else output_extension
+                )
+
+                # Convert input audio to target format.
+                for frame in input_container.decode(stream):
+                    for packet in audio_stream.encode(frame):
+                        output_container.mux(packet)
+
+                # Flush remaining packets to the output.
+                for packet in audio_stream.encode():
+                    output_container.mux(packet)
+        except Exception as e:
+            raise AudioConversionError(f"Error during audio conversion: {e}")
+        finally:
+            input_container.close()
+            output_container.close()
+
+        # Return the converted audio bytes.
+        output_buffer.seek(0)
+        converted_bytes = output_buffer.read()
+        return converted_bytes
+
+    @staticmethod
+    def write_bytes_to_file(bytes: bytes, upload_file: UploadFile):
+        """Writes bytes to a file.
+
+        Args:
+            bytes: The bytes to write.
+            upload_file: The file to write to.
+        """
+        upload_file.file.seek(0)
+        upload_file.file.write(bytes)
+        upload_file.file.seek(0)
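
Note: the pipeline utilities now live in a package, and the new audio.py
module above is part of that split. As an illustration only (assuming an
UploadFile obtained in a FastAPI route handler, and that PyAV resolves
the "mp3" codec name to an available encoder), the converter could be
used like:

    from app.pipelines.utils.audio import AudioConverter

    # `file` is a hypothetical fastapi.UploadFile from a request handler.
    mp3_bytes = AudioConverter.convert(file, output_extension="mp3")
    AudioConverter.write_bytes_to_file(mp3_bytes, file)
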
+ """ + upload_file.file.seek(0) + upload_file.file.write(bytes) + upload_file.file.seek(0) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py new file mode 100644 index 00000000..58de41da --- /dev/null +++ b/runner/app/pipelines/utils/utils.py @@ -0,0 +1,158 @@ +"""This module contains several utility functions.""" + +import logging +import os +import re +import tempfile +import uuid +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from PIL import Image +from torch import dtype as TorchDtype +from transformers import CLIPFeatureExtractor + +logger = logging.getLogger(__name__) + + +def get_model_dir() -> Path: + return Path(os.environ["MODEL_DIR"]) + + +def get_model_path(model_id: str) -> Path: + return get_model_dir() / model_id.lower() + + +def get_torch_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + + +def validate_torch_device(device_name: str) -> bool: + """Checks if the given PyTorch device name is valid and available. + + Args: + device_name: Name of the device ('cuda:0', 'cuda', 'cpu'). + + Returns: + True if valid and available, False otherwise. + """ + try: + device = torch.device(device_name) + if device.type == "cuda": + # Check if CUDA is available and the specified index is within range + if device.index is None: + return torch.cuda.is_available() + else: + return device.index < torch.cuda.device_count() + return True + except RuntimeError: + return False + + +def is_lightning_model(model_id: str) -> bool: + """Checks if the model is a Lightning model. + + Args: + model_id: Model ID. + + Returns: + True if the model is a Lightning model, False otherwise. + """ + return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None + + +def is_turbo_model(model_id: str) -> bool: + """Checks if the model is a Turbo model. + + Args: + model_id: Model ID. + + Returns: + True if the model is a Turbo model, False otherwise. + """ + return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None + + +def get_temp_file(prefix: str, extension: str) -> str: + """Generates a temporary file path with the specified prefix and extension. + + Args: + prefix: The prefix for the temporary file. + extension: The extension for the temporary file. + + Returns: + The path to a non-existing temporary file with the specified prefix and extension. + """ + if not extension.startswith("."): + extension = "." + extension + filename = f"{prefix}{uuid.uuid4()}{extension}" + temp_path = os.path.join(tempfile.gettempdir(), filename) + while os.path.exists(temp_path): + filename = f"{prefix}{uuid.uuid4()}{extension}" + temp_path = os.path.join(tempfile.gettempdir(), filename) + return temp_path + + +class SafetyChecker: + """Checks images for unsafe or inappropriate content using a pretrained model. + + Attributes: + device (str): Device for inference. + """ + + def __init__( + self, + device: Optional[str] = "cuda", + dtype: Optional[TorchDtype] = torch.float16, + ): + """Initializes the SafetyChecker. + + Args: + device: Device for inference. Defaults to "cuda". + dtype: Data type for inference. Defaults to `torch.float16`. + """ + device = device.lower() if device else device + if not validate_torch_device(device): + default_device = get_torch_device() + logger.warning( + f"Device '{device}' not found. 
diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py
index 3f52e36d..942baaff 100644
--- a/runner/app/routes/text_to_image.py
+++ b/runner/app/routes/text_to_image.py
@@ -68,7 +68,8 @@ async def text_to_image(
     for seed in seeds:
         try:
             params.seed = seed
-            imgs, nsfw_check = pipeline(**params.model_dump())
+            kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"}
+            imgs, nsfw_check = pipeline(**kwargs)
             images.extend(imgs)
             has_nsfw_concept.extend(nsfw_check)
         except Exception as e:
diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh
index 375d69c4..13902220 100755
--- a/runner/dl_checkpoints.sh
+++ b/runner/dl_checkpoints.sh
@@ -52,6 +52,7 @@ function download_all_models() {
     huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
     huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
     huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
+    huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
 
     # Download image-to-video models.
     huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
diff --git a/runner/requirements.txt b/runner/requirements.txt
index 7bd794c3..17b38644 100644
--- a/runner/requirements.txt
+++ b/runner/requirements.txt
@@ -1,4 +1,4 @@
-diffusers==0.28.0
+diffusers==0.29.2
 accelerate==0.30.1
 transformers==4.41.1
 fastapi==0.111.0
@@ -14,3 +14,5 @@ deepcache==0.1.1
 safetensors==0.4.3
 scipy==1.13.0
 numpy==1.26.4
+sentencepiece==0.2.0
+protobuf==5.27.2
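
Note on weights: stabilityai/stable-diffusion-3-medium-diffusers is a
gated repository, which is why the new dl_checkpoints.sh line forwards
${TOKEN_FLAG}. As a sketch, an equivalent download from Python (assuming
a Hugging Face token for an account that has accepted the model license)
could look like:

    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="stabilityai/stable-diffusion-3-medium-diffusers",
        cache_dir="models",
        allow_patterns=["*.fp16*.safetensors", "*.model", "*.json", "*.txt"],
        token="hf_...",  # placeholder token
    )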