diff --git a/runner/modal_app.py b/runner/modal_app.py index 1b6a3c99..683ccafb 100644 --- a/runner/modal_app.py +++ b/runner/modal_app.py @@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) +SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning" + @stub.function( image=downloader_image, volumes={models_path: models_volume}, timeout=3600 @@ -72,7 +74,11 @@ def __init__(self, pipeline: str, model_id: str): def enter(self): from app.main import load_pipeline - model_dir = "models--" + self.model_id.replace("/", "--") + model_id = self.model_id + if SDXL_LIGHTNING_MODEL_ID in self.model_id: + model_id = SDXL_LIGHTNING_MODEL_ID + + model_dir = "models--" + model_id.replace("/", "--") path = models_path / model_dir if not path.exists(): models_volume.reload() @@ -122,17 +128,17 @@ def text_to_image_sdxl_lightning_api(): @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) @asgi_app() -def text_to_image_sdxl_turbo_api(): - return make_api("text-to-image", "stabilityai/sdxl-turbo") +def text_to_image_sdxl_lightning_4step_api(): + return make_api("text-to-image", "ByteDance/SDXL-Lightning-4step") @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) @asgi_app() -def text_to_image_sd_1_5_api(): - return make_api("text-to-image", "runwayml/stable-diffusion-v1-5") +def text_to_image_sdxl_lightning_8step_api(): + return make_api("text-to-image", "ByteDance/SDXL-Lightning-8step") @stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) @asgi_app() -def text_to_image_sdxl_api(): - return make_api("text-to-image", "stabilityai/stable-diffusion-xl-base-1.0") +def text_to_image_sdxl_turbo_api(): + return make_api("text-to-image", "stabilityai/sdxl-turbo")