From 8838507f4c03779796853bab00edfc3db43cc0df Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 14 Mar 2024 22:05:55 +0100 Subject: [PATCH 1/3] feat(runner): add support for SDXL-Lightning in image-to-image pipelines This commit adds the [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning/blob/main/sdxl_lightning_samples.jpg) model to the image-to-image pipelines. --- runner/app/pipelines/image_to_image.py | 82 +++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 8 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index d4f70b5b..efa4a958 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -1,8 +1,14 @@ from app.pipelines.base import Pipeline from app.pipelines.util import get_torch_device, get_model_dir -from diffusers import AutoPipelineForImage2Image -from huggingface_hub import file_download +from diffusers import ( + AutoPipelineForImage2Image, + StableDiffusionXLPipeline, + UNet2DConditionModel, + EulerDiscreteScheduler, +) +from safetensors.torch import load_file +from huggingface_hub import file_download, hf_hub_download import torch import PIL from typing import List @@ -15,6 +21,8 @@ logger = logging.getLogger(__name__) +SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning" + class ImageToImagePipeline(Pipeline): def __init__(self, model_id: str): @@ -25,10 +33,13 @@ def __init__(self, model_id: str): repo_id=model_id, repo_type="model" ) folder_path = os.path.join(get_model_dir(), folder_name) - has_fp16_variant = any( - ".fp16.safetensors" in fname - for _, _, files in os.walk(folder_path) - for fname in files + has_fp16_variant = ( + any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + or SDXL_LIGHTNING_MODEL_ID in model_id ) if torch_device != "cpu" and has_fp16_variant: logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id) @@ -37,8 +48,49 @@ def 
__init__(self, model_id: str): kwargs["variant"] = "fp16" self.model_id = model_id - self.ldm = AutoPipelineForImage2Image.from_pretrained(model_id, **kwargs) - self.ldm.to(get_torch_device()) + + # Special case SDXL-Lightning because the unet for SDXL needs to be swapped + if SDXL_LIGHTNING_MODEL_ID in model_id: + base = "stabilityai/stable-diffusion-xl-base-1.0" + + # ByteDance/SDXL-Lightning-2step + if "2step" in model_id: + unet_id = "sdxl_lightning_2step_unet" + # ByteDance/SDXL-Lightning-4step + elif "4step" in model_id: + unet_id = "sdxl_lightning_4step_unet" + # ByteDance/SDXL-Lightning-8step + elif "8step" in model_id: + unet_id = "sdxl_lightning_8step_unet" + else: + # Default to 2step + unet_id = "sdxl_lightning_2step_unet" + + unet = UNet2DConditionModel.from_config( + base, subfolder="unet", cache_dir=kwargs["cache_dir"] + ).to(torch_device, kwargs["torch_dtype"]) + unet.load_state_dict( + load_file( + hf_hub_download( + SDXL_LIGHTNING_MODEL_ID, + f"{unet_id}.safetensors", + cache_dir=kwargs["cache_dir"], + ), + device=str(torch_device), + ) + ) + + self.ldm = StableDiffusionXLPipeline.from_pretrained( + base, unet=unet, **kwargs + ).to(torch_device) + + self.ldm.scheduler = EulerDiscreteScheduler.from_config( + self.ldm.scheduler.config, timestep_spacing="trailing" + ) + else: + self.ldm = AutoPipelineForImage2Image.from_pretrained( + model_id, **kwargs + ).to(torch_device) if os.environ.get("SFAST"): logger.info( @@ -76,6 +128,20 @@ def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: if "num_inference_steps" not in kwargs: kwargs["num_inference_steps"] = 2 + elif SDXL_LIGHTNING_MODEL_ID in self.model_id: + # SDXL-Lightning models should have guidance_scale = 0 and use + # the correct number of inference steps for the unet checkpoint loaded + kwargs["guidance_scale"] = 0.0 + + if "2step" in self.model_id: + kwargs["num_inference_steps"] = 2 + elif "4step" in self.model_id: + kwargs["num_inference_steps"] = 4 + elif "8step" 
in self.model_id: + kwargs["num_inference_steps"] = 8 + else: + # Default to 2step + kwargs["num_inference_steps"] = 2 return self.ldm(prompt, image=image, **kwargs).images From a5860dea674a55860e42153c7a0c1b37796e4e80 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 14 Mar 2024 23:45:15 +0100 Subject: [PATCH 2/3] feat(runner): add SDXL-Lightning image-to-image modal endpoints This commit adds the [modal](https://modal.com/) endpoints for the [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning/blob/main/sdxl_lightning_samples.jpg) image-to-image pipelines. --- runner/modal_app.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/runner/modal_app.py b/runner/modal_app.py index 7df2251a..410a607b 100644 --- a/runner/modal_app.py +++ b/runner/modal_app.py @@ -164,6 +164,24 @@ def text_to_image_sdxl_lightning_8step_api(): return make_api("text-to-image", "ByteDance/SDXL-Lightning-8step") +@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +@asgi_app() +def image_to_image_sdxl_lightning_api(): + return make_api("image-to-image", "ByteDance/SDXL-Lightning") + + +@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +@asgi_app() +def image_to_image_sdxl_lightning_4step_api(): + return make_api("image-to-image", "ByteDance/SDXL-Lightning-4step") + + +@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +@asgi_app() +def image_to_image_sdxl_lightning_8step_api(): + return make_api("image-to-image", "ByteDance/SDXL-Lightning-8step") + + @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) @asgi_app() def text_to_image_sdxl_turbo_api(): From 7334b3afadd80a8f97a4d1437490bfd4f80ed650 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 15 Mar 2024 21:27:04 +0100 Subject: [PATCH 3/3] refactor(runner): disable SVD 1.0 endpoint in modal deployment This commit comments out the `stabilityai/stable-video-diffusion-img2vid-xt` function to adhere to
modal.com's starter plan which only includes 8 deployed endpoints. --- runner/modal_app.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/runner/modal_app.py b/runner/modal_app.py index 410a607b..9a5adf1f 100644 --- a/runner/modal_app.py +++ b/runner/modal_app.py @@ -188,12 +188,12 @@ def text_to_image_sdxl_turbo_api(): return make_api("text-to-image", "stabilityai/sdxl-turbo") -@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) -@asgi_app() -def image_to_video_svd_api(): - return make_api( - "image-to-video", "stabilityai/stable-video-diffusion-img2vid-xt", "A100" - ) +# @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +# @asgi_app() +# def image_to_video_svd_api(): +# return make_api( +# "image-to-video", "stabilityai/stable-video-diffusion-img2vid-xt", "A100" +# ) @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")])