Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(runner): disable DEEPCACHE for lightning/turbo models #93

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker
from app.pipelines.util import (
get_torch_device,
get_model_dir,
SafetyChecker,
is_lightning_model,
is_turbo_model,
)

from diffusers import (
AutoPipelineForImage2Image,
Expand Down Expand Up @@ -119,14 +125,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
"ImageToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"ImageToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
12 changes: 10 additions & 2 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker, is_lightning_model, is_turbo_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,14 +139,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"TextToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
25 changes: 25 additions & 0 deletions runner/app/pipelines/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers import CLIPFeatureExtractor
from typing import Optional
import logging
import re

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,6 +52,30 @@ def validate_torch_device(device_name: str) -> bool:
return False


def is_lightning_model(model_id: str) -> bool:
"""Checks if the model is a Lightning model.

Args:
model_id: Model ID.

Returns:
True if the model is a Lightning model, False otherwise.
"""
return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None


def is_turbo_model(model_id: str) -> bool:
"""Checks if the model is a Turbo model.

Args:
model_id: Model ID.

Returns:
True if the model is a Turbo model, False otherwise.
"""
return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None


class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.

Expand Down