diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py
index e2d6c692..0dc54278 100644
--- a/runner/app/pipelines/text_to_image.py
+++ b/runner/app/pipelines/text_to_image.py
@@ -21,6 +21,7 @@
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models import AutoencoderKL
 from huggingface_hub import file_download, hf_hub_download
 from safetensors.torch import load_file
 
@@ -32,6 +33,7 @@ class ModelName(Enum):
 
     SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
     SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"
+    REALISTIC_VISION_V6 = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
 
     @classmethod
     def list(cls):
@@ -69,6 +71,11 @@ def __init__(self, model_id: str):
         if os.environ.get("BFLOAT16"):
             logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id)
             kwargs["torch_dtype"] = torch.bfloat16
+
+        # Load VAE for specific models.
+        if ModelName.REALISTIC_VISION_V6.value in model_id:
+            vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
+            kwargs["vae"] = vae
 
         # Special case SDXL-Lightning because the unet for SDXL needs to be swapped
         if ModelName.SDXL_LIGHTNING.value in model_id:
diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh
index 822590d4..9fe40837 100755
--- a/runner/dl_checkpoints.sh
+++ b/runner/dl_checkpoints.sh
@@ -56,6 +56,7 @@ function download_all_models() {
     huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
     huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
     huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
+    huggingface-cli download SG161222/Realistic_Vision_V6.0_B1_noVAE --include "*.fp16.safetensors" "*.json" "*.txt" "*.bin" --exclude ".onnx" ".onnx_data" --cache-dir models
 
     # Download image-to-video models.
     huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models