Skip to content

Commit

Permalink
feat(runner): disable DEEPCACHE for lightning/turbo models (livepeer#93)
Browse files Browse the repository at this point in the history
* feat(runner): disable DEEPCACHE for lightning/turbo models

This commit ensures that people cannot use the DeepCache optimization
with Lightning and Turbo models. As explained in
livepeer#82 (comment),
this optimization does not offer any speedup for these models, while it
does reduce image quality.

* fix: ensure correct lightning/turbo matching

This commit ensures that Lightning/Turbo model names using either the `-`
or the `_` separator are matched when checking whether the orchestrator
loaded a Lightning/Turbo model.
  • Loading branch information
rickstaa authored and eliteprox committed Jul 26, 2024
1 parent d824026 commit 5d02d5a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
20 changes: 17 additions & 3 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker
from app.pipelines.util import (
get_torch_device,
get_model_dir,
SafetyChecker,
is_lightning_model,
is_turbo_model,
)

from diffusers import (
AutoPipelineForImage2Image,
Expand Down Expand Up @@ -119,14 +125,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
"ImageToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"ImageToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
12 changes: 10 additions & 2 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker, is_lightning_model, is_turbo_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,14 +139,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"TextToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
25 changes: 25 additions & 0 deletions runner/app/pipelines/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers import CLIPFeatureExtractor
from typing import Optional
import logging
import re

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,6 +52,30 @@ def validate_torch_device(device_name: str) -> bool:
return False


def is_lightning_model(model_id: str) -> bool:
    """Checks if the model is a Lightning model.

    The check is purely name-based: the model ID is searched, ignoring case,
    for the word "lightning" immediately preceded by a ``-`` or ``_``
    separator. The match may occur anywhere in the ID.

    Args:
        model_id: Model ID.

    Returns:
        True if the model is a Lightning model, False otherwise.
    """
    # The separator is part of the pattern, so a bare "lightning" with no
    # leading "-" or "_" is not treated as a Lightning model.
    return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None


def is_turbo_model(model_id: str) -> bool:
    """Checks if the model is a Turbo model.

    The check is purely name-based: the model ID is searched, ignoring case,
    for the word "turbo" immediately preceded by a ``-`` or ``_`` separator.
    The match may occur anywhere in the ID.

    Args:
        model_id: Model ID.

    Returns:
        True if the model is a Turbo model, False otherwise.
    """
    # The separator is part of the pattern, so a bare "turbo" with no leading
    # "-" or "_" is not treated as a Turbo model.
    return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None


class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.
Expand Down

0 comments on commit 5d02d5a

Please sign in to comment.