Skip to content

Commit

Permalink
Sfast optimization (#8)
Browse files Browse the repository at this point in the history
* update to the frame interpolation pipeline, there is some minor issue with creating go api bindings because of openapi json sceme having a null option.

* minor changes to requirements

* update to requrements to fetch from --index-url

* simple patch to solve the go api bindings issue

* checking if it works in my system

* test-examples for frame-interpolation

* update to sfast optimization to i2i and t2i and upscale pipelines

* changes to extra files

* added git ignore to the files to remove unnecessary files

* files not removed checking again

* still in test phase

* test-test

* Update .gitignore

* Delete runner/app/tests-examples directory

* update to directory reader as it now reads almost any naming convention

---------

Co-authored-by: Jason Stone <[email protected]>
  • Loading branch information
JJassonn69 and jjassonn authored Jul 29, 2024
1 parent 672f5fd commit a9683c8
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 19 deletions.
30 changes: 25 additions & 5 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from enum import Enum
import time
from typing import List, Optional, Tuple

import PIL
Expand Down Expand Up @@ -30,6 +31,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""
Expand Down Expand Up @@ -142,11 +144,29 @@ def __init__(self, model_id: str):
# Warm-up the pipeline.
# TODO: Not yet supported for ImageToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"ImageToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
warmup_kwargs = {
"prompt":"A warmed up pipeline is a happy pipeline a short poem by ricksta",
"image": PIL.Image.new("RGB", (576, 1024)),
"strength": 0.8,
"negative_prompt": "No blurry or weird artifacts",
"num_images_per_prompt":4,
}

logger.info("Warming up ImageToImagePipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
logger.error(f"ImageToImagePipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module provides a function to enable StableFast optimization for the pipeline.
For more information, see the DeepCache project on GitHub: https://github.com/chengzeyi/stable-fast
For more information, see the Stable Fast project on GitHub: https://github.com/chengzeyi/stable-fast
"""

import logging
Expand Down
33 changes: 26 additions & 7 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from enum import Enum
import time
from typing import List, Optional, Tuple

import PIL
Expand All @@ -26,6 +27,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""
Expand Down Expand Up @@ -151,14 +153,31 @@ def __init__(self, model_id: str):

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for TextToImagePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"TextToImagePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"prompt": "A happy pipe in the line looking at the wall with words sfast",
"num_images_per_prompt": 4,
"negative_prompt": "No blurry or weird artifacts",
}

logger.info("Warming up TextToImagePipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"TextToImagePipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand Down
30 changes: 25 additions & 5 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import time
from typing import List, Optional, Tuple

import PIL
Expand All @@ -21,6 +22,7 @@

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled.

class UpscalePipeline(Pipeline):
def __init__(self, model_id: str):
Expand Down Expand Up @@ -68,11 +70,29 @@ def __init__(self, model_id: str):
# Warm-up the pipeline.
# TODO: Not yet supported for UpscalePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"UpscalePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)
# Retrieve default model params.
# TODO: Retrieve defaults from Pydantic class in route.
warmup_kwargs = {
"prompt": "Upscaling the pipeline with sfast enabled",
"image": PIL.Image.new("RGB", (576, 1024)),
}

logger.info("Warming up ImageToVideoPipeline pipeline...")
total_time = 0
for ii in range(SFAST_WARMUP_ITERATIONS):
t = time.time()
try:
self.ldm(**warmup_kwargs).images
except Exception as e:
# FIXME: When out of memory, pipeline is corrupted.
logger.error(f"ImageToVideoPipeline warmup error: {e}")
raise e
iteration_time = time.time() - t
total_time += iteration_time
logger.info(
"Warmup iteration %s took %s seconds", ii + 1, iteration_time
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
Expand Down
13 changes: 12 additions & 1 deletion runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,22 @@ def check_nsfw_images(
)
return images, has_nsfw_concept


def natural_sort_key(s):
"""
Sort in a natural order, separating strings into a list of strings and integers.
This handles leading zeros and case insensitivity.
"""
return [
int(text) if text.isdigit() else text.lower()
for text in re.split(r'([0-9]+)', os.path.basename(s))
]

class DirectoryReader:
def __init__(self, dir: str):
self.paths = sorted(
glob.glob(os.path.join(dir, "*")),
key=lambda x: int(os.path.basename(x).split(".")[0]),
key=natural_sort_key
)
self.nb_frames = len(self.paths)
self.idx = 0
Expand Down

0 comments on commit a9683c8

Please sign in to comment.