From 5d02d5a12febba9fb88dc34ec6aa78e24c77f1a8 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 4 Jun 2024 14:54:16 +0200 Subject: [PATCH] feat(runner): disable DEEPCACHE for lightning/turbo models (#93) * feat(runner): disable DEEPCACHE for lightning/turbo models This commit ensures that people can not use the deepcache optimization with lightning and turbo models. As explained in https://github.com/livepeer/ai-worker/issues/82#issuecomment-2141983903 this optimization does not offer any speedup for these models while it does reduce image quality. * fix: ensure correct lighthing/turbo matching This commit ensures that Lightning/Turbo model names with both the `-` and `_` are matched when checking whether the orchestrator loaded a Lightning/Turbo model. --- runner/app/pipelines/image_to_image.py | 20 +++++++++++++++++--- runner/app/pipelines/text_to_image.py | 12 ++++++++++-- runner/app/pipelines/util.py | 25 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 9ddf687a..a2ec720f 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -1,5 +1,11 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker +from app.pipelines.util import ( + get_torch_device, + get_model_dir, + SafetyChecker, + is_lightning_model, + is_turbo_model, +) from diffusers import ( AutoPipelineForImage2Image, @@ -119,14 +125,22 @@ def __init__(self, model_id: str): "call may be slow if 'SFAST' is enabled." ) - if deepcache_enabled: + if deepcache_enabled and not ( + is_lightning_model(model_id) or is_turbo_model(model_id) + ): logger.info( - "TextToImagePipeline will be optimized with DeepCache for %s", + "ImageToImagePipeline will be optimized with DeepCache for %s", model_id, ) from app.pipelines.optim.deepcache import enable_deepcache self.ldm = enable_deepcache(self.ldm) + elif deepcache_enabled: + logger.warning( + "DeepCache is not supported for Lightning or Turbo models. " + "ImageToImagePipeline will NOT be optimized with DeepCache for %s", + model_id, + ) safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() self._safety_checker = SafetyChecker(device=safety_checker_device) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index db3ac6e8..eedf4132 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -14,7 +14,7 @@ from safetensors.torch import load_file from app.pipelines.base import Pipeline -from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker +from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker, is_lightning_model, is_turbo_model logger = logging.getLogger(__name__) @@ -139,7 +139,9 @@ def __init__(self, model_id: str): "call may be slow if 'SFAST' is enabled." ) - if deepcache_enabled: + if deepcache_enabled and not ( + is_lightning_model(model_id) or is_turbo_model(model_id) + ): logger.info( "TextToImagePipeline will be optimized with DeepCache for %s", model_id, @@ -147,6 +149,12 @@ def __init__(self, model_id: str): from app.pipelines.optim.deepcache import enable_deepcache self.ldm = enable_deepcache(self.ldm) + elif deepcache_enabled: + logger.warning( + "DeepCache is not supported for Lightning or Turbo models. " + "TextToImagePipeline will NOT be optimized with DeepCache for %s", + model_id, + ) safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower() self._safety_checker = SafetyChecker(device=safety_checker_device) diff --git a/runner/app/pipelines/util.py b/runner/app/pipelines/util.py index 2a73b798..584d788c 100644 --- a/runner/app/pipelines/util.py +++ b/runner/app/pipelines/util.py @@ -8,6 +8,7 @@ from transformers import CLIPFeatureExtractor from typing import Optional import logging +import re logger = logging.getLogger(__name__) @@ -51,6 +52,30 @@ def validate_torch_device(device_name: str) -> bool: return False +def is_lightning_model(model_id: str) -> bool: + """Checks if the model is a Lightning model. + + Args: + model_id: Model ID. + + Returns: + True if the model is a Lightning model, False otherwise. + """ + return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None + + +def is_turbo_model(model_id: str) -> bool: + """Checks if the model is a Turbo model. + + Args: + model_id: Model ID. + + Returns: + True if the model is a Turbo model, False otherwise. + """ + return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None + + class SafetyChecker: """Checks images for unsafe or inappropriate content using a pretrained model.