Skip to content

Commit

Permalink
perf: pre-warm ImageToVideo SFAST pipeline
Browse files Browse the repository at this point in the history
This commit ensures that the model is pre-traced when SFAST is enabled
for the ImageToVideo pipeline. Without this pre-tracing, the first
request will be slower than a non-SFAST call.
  • Loading branch information
rickstaa committed Apr 15, 2024
1 parent 74e31fc commit 74fefad
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
14 changes: 12 additions & 2 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,25 @@ def __init__(self, model_id: str):
model_id, **kwargs
).to(torch_device)

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"ImageToImagePipeline will be dynamicallly compiled with stable-fast for %s",
"ImageToImagePipeline will be dynamically compiled with stable-fast "
"for %s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for ImageToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"ImageToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
Expand Down
42 changes: 40 additions & 2 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from typing import List
import logging
import os
import time

from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.


class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
Expand All @@ -40,17 +43,52 @@ def __init__(self, model_id: str):
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
self.ldm.to(get_torch_device())

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"ImageToVideoPipeline will be dynamicallly compiled with stable-fast for %s",
"ImageToVideoPipeline will be dynamically compiled with stable-fast "
"for %s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# NOTE: Initial calls may be slow due to compilation. Subsequent calls will
# be faster.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"image": PIL.Image.new("RGB", (576, 1024)),
"height": 576,
"width": 1024,
"fps": 6,
"motion_bucket_id": 127,
"noise_aug_strength": 0.02,
"decode_chunk_size": 25,
}

logger.info("Warming up ImageToVideoPipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).frames
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"ImageToVideoPipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
if "decode_chunk_size" not in kwargs:
# Decrease decode_chunk_size to reduce memory usage.
kwargs["decode_chunk_size"] = 4

seed = kwargs.pop("seed", None)
Expand Down
18 changes: 15 additions & 3 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from typing import List
import logging
import os
import time

logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

WARMUP_ITERATIONS = 3 # Warm-up calls count when SFAST is enabled.
WARMUP_BATCH_SIZE = 3 # Max batch size for warm-up calls when SFAST is enabled.

class TextToImagePipeline(Pipeline):
def __init__(self, model_id: str):
Expand Down Expand Up @@ -111,15 +113,25 @@ def __init__(self, model_id: str):
self.ldm.vae.decode, mode="max-autotune", fullgraph=True
)

if os.environ.get("SFAST"):
if os.getenv("SFAST", "").strip().lower() == "true":
logger.info(
"TextToImagePipeline will be dynamicallly compiled with stable-fast for %s",
"TextToImagePipeline will be dynamically compiled with stable-fast for "
"%s",
model_id,
)
from app.pipelines.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
            # TODO: Not yet supported for TextToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"TextToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)

def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
Expand Down

0 comments on commit 74fefad

Please sign in to comment.