Skip to content

Commit

Permalink
perf: pre-warm ImageToVideo SFAST pipeline
Browse files Browse the repository at this point in the history
This commit ensures that the model is pre-traced when SFAST is enabled
for the ImageToVideo pipeline. Without this pre-tracing the first
request will be slower than a non SFAST call.
  • Loading branch information
rickstaa committed Apr 15, 2024
1 parent 74e31fc commit 2aae19a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
14 changes: 12 additions & 2 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,25 @@ def __init__(self, model_id: str):
model_id, **kwargs
).to(torch_device)

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"ImageToImagePipeline will be dynamicallly compiled with stable-fast for %s",
"ImageToImagePipeline will be dynamically compiled with stable-fast "
"for %s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for ImageToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"ImageToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
Expand Down
42 changes: 40 additions & 2 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from typing import List
import logging
import os
import time

from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.


class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
Expand All @@ -40,17 +43,52 @@ def __init__(self, model_id: str):
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
self.ldm.to(get_torch_device())

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"ImageToVideoPipeline will be dynamicallly compiled with stable-fast for %s",
"ImageToVideoPipeline will be dynamically compiled with stable-fast "
"for %s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# NOTE: Initial calls may be slow due to compilation. Subsequent calls will
# be faster.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"image": PIL.Image.new("RGB", (576, 1024)),
"height": 576,
"width": 1024,
"fps": 6,
"motion_bucket_id": 127,
"noise_aug_strength": 0.02,
"decode_chunk_size": 25,
}

logger.info("Warming up ImageToVideoPipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).frames
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"ImageToVideoPipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
if "decode_chunk_size" not in kwargs:
# Decrease decode_chunk_size to reduce memory usage.
kwargs["decode_chunk_size"] = 4

seed = kwargs.pop("seed", None)
Expand Down
14 changes: 12 additions & 2 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,25 @@ def __init__(self, model_id: str):
self.ldm.vae.decode, mode="max-autotune", fullgraph=True
)

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"TextToImagePipeline will be dynamicallly compiled with stable-fast for %s",
"TextToImagePipeline will be dynamically compiled with stable-fast for "
"%s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for ImageToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"TextToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
Expand Down

0 comments on commit 2aae19a

Please sign in to comment.