feat(runner): add support for SD3-medium model (#118)
This commit introduces support for the Stable Diffusion 3 Medium model
from Hugging Face:
[https://huggingface.co/stabilityai/stable-diffusion-3-medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium).

Please be aware that, at the time of writing, this model is distributed under a
restrictive license and is not yet recommended for public use. Read and
understand the [licensing
terms](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)
before enabling this model on your orchestrator.
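
For operators who do decide to enable it, a minimal sketch of pre-downloading the gated weights with `huggingface_hub` after accepting the license (the `HF_TOKEN` variable name is illustrative; `dl_checkpoints.sh` in this commit does the equivalent with `huggingface-cli`):

# Hedged sketch: assumes the license has been accepted on Hugging Face and
# an access token is available. The HF_TOKEN variable name is illustrative.
import os

from huggingface_hub import snapshot_download

snapshot_download(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    cache_dir="models",
    allow_patterns=["*.fp16*.safetensors", "*.model", "*.json", "*.txt"],
    token=os.environ.get("HF_TOKEN"),
)
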
rickstaa committed Jul 16, 2024
1 parent b059e9b commit 0d03040
Showing 7 changed files with 274 additions and 7 deletions.
26 changes: 21 additions & 5 deletions runner/app/pipelines/text_to_image.py
@@ -1,6 +1,7 @@
import logging
import os
from typing import List, Tuple, Optional
from enum import Enum

import PIL
import torch
@@ -9,6 +10,7 @@
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
StableDiffusion3Pipeline,
)
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file
@@ -24,7 +26,17 @@

logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""

SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"

@classmethod
def list(cls):
"""Return a list of all model IDs."""
return list(map(lambda c: c.value, cls))


class TextToImagePipeline(Pipeline):
@@ -46,7 +58,7 @@ def __init__(self, model_id: str):
for _, _, files in os.walk(folder_path)
for fname in files
)
or SDXL_LIGHTNING_MODEL_ID in model_id
or ModelName.SDXL_LIGHTNING.value in model_id
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("TextToImagePipeline loading fp16 variant for %s", model_id)
@@ -59,7 +71,7 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.bfloat16

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if SDXL_LIGHTNING_MODEL_ID in model_id:
if ModelName.SDXL_LIGHTNING.value in model_id:
base = "stabilityai/stable-diffusion-xl-base-1.0"

# ByteDance/SDXL-Lightning-2step
@@ -81,7 +93,7 @@ def __init__(self, model_id: str):
unet.load_state_dict(
load_file(
hf_hub_download(
SDXL_LIGHTNING_MODEL_ID,
ModelName.SDXL_LIGHTNING.value,
f"{unet_id}.safetensors",
cache_dir=kwargs["cache_dir"],
),
@@ -96,6 +108,10 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif ModelName.SD3_MEDIUM.value in model_id:
self.ldm = StableDiffusion3Pipeline.from_pretrained(model_id, **kwargs).to(
torch_device
)
else:
self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
torch_device
@@ -190,7 +206,7 @@ def __call__(
# SD turbo models were trained without guidance_scale so
# it should be set to 0
kwargs["guidance_scale"] = 0.0
elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
elif ModelName.SDXL_LIGHTNING.value in self.model_id:
# SDXL-Lightning models should have guidance_scale = 0 and use
# the correct number of inference steps for the unet checkpoint loaded
kwargs["guidance_scale"] = 0.0
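
For readers skimming the diff, here is a standalone sketch of the dispatch the new enum drives; the select_pipeline_cls helper is hypothetical, and in the commit the branching lives inside TextToImagePipeline.__init__:

# Hedged sketch of the model-selection logic added above; select_pipeline_cls
# is hypothetical and only mirrors the branching in TextToImagePipeline.__init__.
from enum import Enum


class ModelName(Enum):
    """Mirror of the enum added in text_to_image.py."""

    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
    SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"


def select_pipeline_cls(model_id: str) -> str:
    """Return the name of the diffusers pipeline class the runner would load."""
    if ModelName.SDXL_LIGHTNING.value in model_id:
        # The SDXL base pipeline is used with a swapped-in Lightning unet.
        return "StableDiffusionXLPipeline"
    if ModelName.SD3_MEDIUM.value in model_id:
        return "StableDiffusion3Pipeline"
    return "AutoPipelineForText2Image"


print(select_pipeline_cls(ModelName.SD3_MEDIUM.value))  # StableDiffusion3Pipeline
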
12 changes: 12 additions & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -0,0 +1,12 @@
"""This module contains several utility functions that are used across the pipelines module."""

from app.pipelines.utils.utils import (
get_model_dir,
get_model_path,
get_torch_device,
validate_torch_device,
is_lightning_model,
is_turbo_model,
get_temp_file,
SafetyChecker,
)
77 changes: 77 additions & 0 deletions runner/app/pipelines/utils/audio.py
@@ -0,0 +1,77 @@
"""This module provides functionality for converting audio files between different formats."""

from io import BytesIO

import av
from fastapi import UploadFile


class AudioConversionError(Exception):
"""Raised when an audio file cannot be converted."""

def __init__(self, message="Audio conversion failed."):
self.message = message
super().__init__(self.message)


class AudioConverter:
"""Converts audio files to different formats."""

@staticmethod
def convert(
upload_file: UploadFile, output_extension: str, output_codec=None
) -> bytes:
"""Converts an audio file to a different format.
Args:
upload_file: The audio file to convert.
output_extension: The desired output format.
output_codec: The desired output codec.
Returns:
The converted audio file as bytes.
"""
if output_extension.startswith("."):
output_extension = output_extension.lstrip(".")

output_buffer = BytesIO()

input_container = av.open(upload_file.file)
output_container = av.open(output_buffer, mode="w", format=output_extension)

try:
for stream in input_container.streams.audio:
audio_stream = output_container.add_stream(
output_codec if output_codec else output_extension
)

# Convert input audio to target format.
for frame in input_container.decode(stream):
for packet in audio_stream.encode(frame):
output_container.mux(packet)

# Flush remaining packets to the output.
for packet in audio_stream.encode():
output_container.mux(packet)
except Exception as e:
raise AudioConversionError(f"Error during audio conversion: {e}")
finally:
input_container.close()
output_container.close()

# Return the converted audio bytes.
output_buffer.seek(0)
converted_bytes = output_buffer.read()
return converted_bytes

@staticmethod
def write_bytes_to_file(bytes: bytes, upload_file: UploadFile):
"""Writes bytes to a file.
Args:
bytes: The bytes to write.
upload_file: The file to write to.
"""
upload_file.file.seek(0)
upload_file.file.write(bytes)
upload_file.file.seek(0)
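
A hedged usage sketch for the converter above; the input path and target format are illustrative, and MP3 output assumes the local FFmpeg build ships an MP3 encoder:

# Hedged usage sketch; "sample.wav" and the mp3 target are illustrative.
from fastapi import UploadFile

from app.pipelines.utils.audio import AudioConverter

with open("sample.wav", "rb") as f:
    upload = UploadFile(file=f, filename="sample.wav")
    mp3_bytes = AudioConverter.convert(upload, output_extension="mp3")

with open("sample.mp3", "wb") as out:
    out.write(mp3_bytes)
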
158 changes: 158 additions & 0 deletions runner/app/pipelines/utils/utils.py
@@ -0,0 +1,158 @@
"""This module contains several utility functions."""

import logging
import os
import re
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torch import dtype as TorchDtype
from transformers import CLIPFeatureExtractor

logger = logging.getLogger(__name__)


def get_model_dir() -> Path:
return Path(os.environ["MODEL_DIR"])


def get_model_path(model_id: str) -> Path:
return get_model_dir() / model_id.lower()


def get_torch_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")


def validate_torch_device(device_name: str) -> bool:
"""Checks if the given PyTorch device name is valid and available.
Args:
device_name: Name of the device ('cuda:0', 'cuda', 'cpu').
Returns:
True if valid and available, False otherwise.
"""
try:
device = torch.device(device_name)
if device.type == "cuda":
# Check if CUDA is available and the specified index is within range
if device.index is None:
return torch.cuda.is_available()
else:
return device.index < torch.cuda.device_count()
return True
except RuntimeError:
return False


def is_lightning_model(model_id: str) -> bool:
"""Checks if the model is a Lightning model.
Args:
model_id: Model ID.
Returns:
True if the model is a Lightning model, False otherwise.
"""
return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None


def is_turbo_model(model_id: str) -> bool:
"""Checks if the model is a Turbo model.
Args:
model_id: Model ID.
Returns:
True if the model is a Turbo model, False otherwise.
"""
return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None


def get_temp_file(prefix: str, extension: str) -> str:
"""Generates a temporary file path with the specified prefix and extension.
Args:
prefix: The prefix for the temporary file.
extension: The extension for the temporary file.
Returns:
The path to a non-existing temporary file with the specified prefix and extension.
"""
if not extension.startswith("."):
extension = "." + extension
filename = f"{prefix}{uuid.uuid4()}{extension}"
temp_path = os.path.join(tempfile.gettempdir(), filename)
while os.path.exists(temp_path):
filename = f"{prefix}{uuid.uuid4()}{extension}"
temp_path = os.path.join(tempfile.gettempdir(), filename)
return temp_path


class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.
Attributes:
device (str): Device for inference.
"""

def __init__(
self,
device: Optional[str] = "cuda",
dtype: Optional[TorchDtype] = torch.float16,
):
"""Initializes the SafetyChecker.
Args:
device: Device for inference. Defaults to "cuda".
dtype: Data type for inference. Defaults to `torch.float16`.
"""
device = device.lower() if device else device
if not validate_torch_device(device):
default_device = get_torch_device()
logger.warning(
f"Device '{device}' not found. Defaulting to '{default_device}'."
)
device = default_device

self.device = device
self._dtype = dtype
self._safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
).to(self.device)
self._feature_extractor = CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32"
)

def check_nsfw_images(
self, images: list[Image.Image]
) -> tuple[list[Image.Image], list[bool]]:
"""Checks images for unsafe content.
Args:
images: Images to check.
Returns:
Tuple of images and corresponding NSFW flags.
"""
safety_checker_input = self._feature_extractor(images, return_tensors="pt").to(
self.device
)
images_np = [np.array(img) for img in images]
_, has_nsfw_concept = self._safety_checker(
images=images_np,
clip_input=safety_checker_input.pixel_values.to(self._dtype),
)
return images, has_nsfw_concept
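
A hedged usage sketch for SafetyChecker; the image path is illustrative, and CPU with float32 is used so the example does not require a GPU:

# Hedged usage sketch; "example.png" is illustrative.
import torch
from PIL import Image

from app.pipelines.utils import SafetyChecker

checker = SafetyChecker(device="cpu", dtype=torch.float32)
images, nsfw_flags = checker.check_nsfw_images([Image.open("example.png")])
print(nsfw_flags)  # e.g. [False]
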
3 changes: 2 additions & 1 deletion runner/app/routes/text_to_image.py
@@ -68,7 +68,8 @@ async def text_to_image(
for seed in seeds:
try:
params.seed = seed
imgs, nsfw_check = pipeline(**params.model_dump())
kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"}
imgs, nsfw_check = pipeline(**kwargs)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except Exception as e:
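
The filter above strips model_id, which identifies the model rather than being a generation argument, before forwarding the remaining parameters to the pipeline; a minimal illustration with a plain dict standing in for params.model_dump():

# Minimal illustration of stripping model_id before calling the pipeline;
# the dict below is a stand-in for params.model_dump().
params = {
    "model_id": "stabilityai/stable-diffusion-3-medium-diffusers",
    "prompt": "a photo of a cat",
    "seed": 42,
}
kwargs = {k: v for k, v in params.items() if k != "model_id"}
print(kwargs)  # {'prompt': 'a photo of a cat', 'seed': 42}
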
1 change: 1 addition & 0 deletions runner/dl_checkpoints.sh
@@ -52,6 +52,7 @@ function download_all_models() {
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
4 changes: 3 additions & 1 deletion runner/requirements.txt
@@ -1,4 +1,4 @@
diffusers==0.28.0
diffusers==0.29.2
accelerate==0.30.1
transformers==4.41.1
fastapi==0.111.0
@@ -14,3 +14,5 @@ deepcache==0.1.1
safetensors==0.4.3
scipy==1.13.0
numpy==1.26.4
sentencepiece==0.2.0
protobuf==5.27.2
