Skip to content

Commit

Permalink
feat(model): add Realistic Vision model T2I support (livepeer#136)
Browse files Browse the repository at this point in the history
This commit ensures that the https://huggingface.co/SG161222/Realistic_Vision_V6.0_B1_noVAE
model is supported in the T2I pipeline.
  • Loading branch information
rickstaa authored Jul 30, 2024
1 parent 8c03423 commit 8da8ed0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.models import AutoencoderKL
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

Expand All @@ -32,6 +33,7 @@ class ModelName(Enum):

SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"
REALISTIC_VISION_V6 = "SG161222/Realistic_Vision_V6.0_B1_noVAE"

@classmethod
def list(cls):
Expand Down Expand Up @@ -69,6 +71,11 @@ def __init__(self, model_id: str):
if os.environ.get("BFLOAT16"):
logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

# Load VAE for specific models.
if ModelName.REALISTIC_VISION_V6.value in model_id:
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
kwargs["vae"] = vae

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if ModelName.SDXL_LIGHTNING.value in model_id:
Expand Down
1 change: 1 addition & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ function download_all_models() {
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
huggingface-cli download SG161222/Realistic_Vision_V6.0_B1_noVAE --include "*.fp16.safetensors" "*.json" "*.txt" "*.bin" --exclude ".onnx" ".onnx_data" --cache-dir models

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
Expand Down

0 comments on commit 8da8ed0

Please sign in to comment.