Skip to content

Commit

Permalink
feat(runner): disable DEEPCACHE for lightning/turbo models
Browse files Browse the repository at this point in the history
This commit ensures that people cannot use the DeepCache optimization
with Lightning and Turbo models. As explained in
#82 (comment),
this optimization does not offer any speedup for these models, while it
does reduce image quality.
  • Loading branch information
rickstaa committed May 31, 2024
1 parent ace582c commit 45898e3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
20 changes: 17 additions & 3 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker
from app.pipelines.util import (
get_torch_device,
get_model_dir,
SafetyChecker,
is_lightning_model,
is_turbo_model,
)

from diffusers import (
AutoPipelineForImage2Image,
Expand Down Expand Up @@ -119,14 +125,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
"ImageToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"ImageToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
12 changes: 10 additions & 2 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker
from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker, is_lightning_model, is_turbo_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,14 +139,22 @@ def __init__(self, model_id: str):
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled:
if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"TextToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)
Expand Down
22 changes: 22 additions & 0 deletions runner/app/pipelines/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ def validate_torch_device(device_name: str) -> bool:
return False


def is_lightning_model(model_id: str) -> bool:
    """Determine whether a model ID refers to a Lightning model variant.

    Args:
        model_id: Model ID.

    Returns:
        True if the model is a Lightning model, False otherwise.
    """
    # Lightning checkpoints are identified by a "-lightning" tag in the ID
    # (case-insensitive).
    return model_id.lower().find("-lightning") != -1

def is_turbo_model(model_id: str) -> bool:
    """Determine whether a model ID refers to a Turbo model variant.

    Args:
        model_id: Model ID.

    Returns:
        True if the model is a Turbo model, False otherwise.
    """
    # Turbo checkpoints are identified by a "-turbo" tag in the ID
    # (case-insensitive).
    lowered = model_id.lower()
    return "-turbo" in lowered

class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.
Expand Down

0 comments on commit 45898e3

Please sign in to comment.