POC: quick POC for pre-tracing sfast
This is a quick commit that enables pre-tracing for the SFAST variant of
the text-to-image and image-to-video pipelines.

> [!WARNING]
> DO NOT MERGE INTO THE MAIN BRANCH!!!!
rickstaa committed Apr 3, 2024
1 parent 74e31fc commit 6f50000
Showing 3 changed files with 100 additions and 3 deletions.
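Context for the diffs below: stable-fast (SFAST) compiles and traces the model lazily on the first call with a given input shape, so the first user request pays a large latency cost. Pre-tracing moves that cost to startup by issuing throwaway calls right after compilation. A minimal sketch of the pattern, assuming this repo's `compile_model` wrapper from `app/pipelines/sfast.py` (the helper name and dummy kwargs are illustrative, not this commit's code):

```python
# Minimal sketch of the pre-tracing pattern used in this commit.
import logging
import time

from app.pipelines.sfast import compile_model  # repo's stable-fast wrapper

logger = logging.getLogger(__name__)


def compile_and_warm_up(ldm, warmup_kwargs: dict, iterations: int = 3):
    """Compile with stable-fast, then trace at startup with throwaway calls."""
    ldm = compile_model(ldm)
    for i in range(iterations):
        t = time.time()
        ldm(**warmup_kwargs)  # first call is slow (tracing); later calls are fast
        logger.info("Warmup iteration %d took %s seconds", i + 1, time.time() - t)
    return ldm
```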
37 changes: 36 additions & 1 deletion runner/app/pipelines/image_to_video.py
@@ -8,13 +8,17 @@
from typing import List
import logging
import os
import time

from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

WARMUP_ITERATIONS = 3 # Warm-up calls count when SFAST is enabled.
WARMUP_BATCH_SIZE = 3 # Max batch size for warm-up calls when SFAST is enabled.


class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
@@ -49,6 +53,32 @@ def __init__(self, model_id: str):

self.ldm = compile_model(self.ldm)

# Build dummy warmup inputs at the pipeline's default resolution.
warmup_kwargs = {
"image": PIL.Image.new("RGB", (512, 512)),
"height": 512,
"width": 512,
}

# NOTE: Warmup pipeline.
# The initial calls will trigger compilation and might be very slow.
# After that, it should be very fast.
# FIXME: This will crash the pipeline if there is not enough VRAM available.
logger.info("Warming up pipeline...")
for ii in range(WARMUP_ITERATIONS):
logger.info(f"Warmup iteration {ii + 1}...")
t = time.time()
try:
self.ldm(**warmup_kwargs).frames
except Exception as e:
logger.error(f"ImageToVideoPipeline warmup error: {e}")
logger.exception(e)
# FIXME: After a CUDA OOM the full model must be reloaded before it
# works again; torch.cuda.empty_cache() does not help.
# continue
raise
logger.info("Warmup iteration took %s seconds", time.time() - t)

def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
if "decode_chunk_size" not in kwargs:
kwargs["decode_chunk_size"] = 4
@@ -64,7 +94,12 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

return self.ldm(image, **kwargs).frames
t = time.time()
frames = self.ldm(image, **kwargs).frames
logger.info("ImageToVideoPipeline took %s seconds", time.time() - t)

return frames

def __str__(self) -> str:
return f"ImageToVideoPipeline model_id={self.model_id}"
63 changes: 61 additions & 2 deletions runner/app/pipelines/text_to_image.py
@@ -14,11 +14,13 @@
from typing import List
import logging
import os
import time

logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

WARMUP_ITERATIONS = 3 # Warm-up calls count when SFAST is enabled.
WARMUP_BATCH_SIZE = 3 # Max batch size for warm-up calls when SFAST is enabled.

class TextToImagePipeline(Pipeline):
def __init__(self, model_id: str):
@@ -117,9 +119,61 @@ def __init__(self, model_id: str):
model_id,
)
from app.pipelines.sfast import compile_model
from app.routes.text_to_image import TextToImageParams

self.ldm = compile_model(self.ldm)

# Retrieve default model params.
warmup_kwargs = TextToImageParams(
prompt="A warmed up pipeline is a happy pipeline"
)
if (
self.model_id == "stabilityai/sdxl-turbo"
or self.model_id == "stabilityai/sd-turbo"
):
# SD turbo models were trained without guidance_scale so
# it should be set to 0
warmup_kwargs.guidance_scale = 0.0

if "num_inference_steps" not in kwargs:
warmup_kwargs.num_inference_steps = 1
elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
# SDXL-Lightning models should have guidance_scale = 0 and use
# the correct number of inference steps for the unet checkpoint loaded
warmup_kwargs.guidance_scale = 0.0

if "2step" in self.model_id:
warmup_kwargs.num_inference_steps = 2
elif "4step" in self.model_id:
warmup_kwargs.num_inference_steps = 4
elif "8step" in self.model_id:
warmup_kwargs.num_inference_steps = 8
else:
# Default to 2step
warmup_kwargs.num_inference_steps = 2

# NOTE: Warmup pipeline.
# The initial calls will trigger compilation and might be very slow.
# After that, it should be very fast.
# FIXME: This will crash the pipeline if there is not enough VRAM available.
logger.info("Warming up pipeline...")
for batch in range(WARMUP_BATCH_SIZE):
warmup_kwargs.num_images_per_prompt = batch + 1
logger.info(f"Warmup with batch {batch + 1}...")
for ii in range(WARMUP_ITERATIONS):
logger.info(f"Warmup iteration {ii + 1}...")
t = time.time()
try:
self.ldm(**warmup_kwargs.model_dump()).images[0]
except Exception as e:
logger.error(f"TextToImagePipeline warmup error: {e}")
logger.exception(e)
# FIXME: After a CUDA OOM the full model must be reloaded before it
# works again; torch.cuda.empty_cache() does not help.
# continue
raise
logger.info("Warmup iteration took %s seconds", time.time() - t)

def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
seed = kwargs.pop("seed", None)
if seed is not None:
@@ -157,7 +211,12 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
# Default to 2step
kwargs["num_inference_steps"] = 2

return self.ldm(prompt, **kwargs).images
t = time.time()
images = self.ldm(prompt, **kwargs).images
logger.info("TextToImagePipeline took %s seconds", time.time() - t)

return images

def __str__(self) -> str:
return f"TextToImagePipeline model_id={self.model_id}"
3 changes: 3 additions & 0 deletions runner/app/routes/text_to_image.py
@@ -25,6 +25,9 @@ class TextToImageParams(BaseModel):
negative_prompt: str = ""
seed: int = None
num_images_per_prompt: int = 1
# Model specific parameters.
# These are not used by all models.
num_inference_steps: int = 1


responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
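For reference, the warmup code in `text_to_image.py` consumes this model by dumping it to a plain dict and splatting it into the pipeline call, so a field added here (such as `num_inference_steps`) flows through without further plumbing. A small usage sketch (the prompt value comes from the diff above; the surrounding `ldm` pipeline object is assumed):

```python
# Hedged usage sketch: fields on TextToImageParams reach the diffusion
# pipeline via model_dump(), as in the warmup loop in text_to_image.py.
params = TextToImageParams(prompt="A warmed up pipeline is a happy pipeline")
params.num_inference_steps = 2   # e.g. for an SDXL-Lightning 2-step checkpoint
params.num_images_per_prompt = 1
images = ldm(**params.model_dump()).images  # ldm: the compiled pipeline
```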