runner: Support stable-fast using env var
yondonfu committed Feb 6, 2024
1 parent 5c48959 commit 02592a5
Showing 7 changed files with 70 additions and 3 deletions.
4 changes: 3 additions & 1 deletion runner/Dockerfile
@@ -30,12 +30,14 @@ RUN pyenv install $PYTHON_VERSION && \
 # Upgrade pip and install your desired packages
 ARG PIP_VERSION=23.3.2
 RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools wheel && \
-    pip install --no-cache-dir torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
+    pip install --no-cache-dir torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1

 WORKDIR /app
 COPY ./requirements.txt /app
 RUN pip install --no-cache-dir -r requirements.txt

+RUN pip install https://github.com/chengzeyi/stable-fast/releases/download/v1.0.3/stable_fast-1.0.3+torch211cu121-cp311-cp311-manylinux2014_x86_64.whl
+
 # Most DL models are quite large in terms of memory, using workers is a HUGE
 # slowdown because of the fork and GIL with python.
 # Using multiple pods seems like a better default strategy.
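The stable-fast wheel installed above encodes its build targets in the tag: `torch211cu121-cp311` means torch 2.1.1, CUDA 12.1, CPython 3.11, which is why the torch pin in this same commit moves to 2.1.1. A quick runtime sanity check might look like this (a sketch, not part of the commit):

```
# Sanity-check sketch: confirm the runtime matches the stable-fast wheel tag
# torch211cu121-cp311 pinned in the Dockerfile above.
import sys

import torch

assert sys.version_info[:2] == (3, 11), "wheel targets CPython 3.11"
assert torch.__version__.startswith("2.1.1"), "wheel targets torch 2.1.1"
assert torch.version.cuda == "12.1", "wheel targets CUDA 12.1"
print("environment matches stable_fast-1.0.3+torch211cu121-cp311")
```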
6 changes: 5 additions & 1 deletion runner/README.md
@@ -3,7 +3,7 @@
 ## Build Docker image

 ```
-docker build -t runner .
+docker build -t livepeer/ai-runner:latest .
 ```

 ## Download models
@@ -19,6 +19,10 @@ pip install "huggingface_hub[cli]"
 ./dl-checkpoints.sh
 ```

+## Optimizations
+
+- Set the environment variable `SFAST=true` to enable dynamic compilation with [stable-fast](https://github.com/chengzeyi/stable-fast), which speeds up inference for diffusion pipelines (the first requests are slower while the model is compiled on the fly; see the usage sketch after this file's diff).
+
 ## Run text-to-image container

 Run container:
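Once the image is built, a minimal way to exercise the new flag might be (a sketch; the `--gpus all` flag is an assumption, since the container's run command is not shown in this diff):

```
# Hypothetical usage sketch: enable stable-fast dynamic compilation via the
# SFAST env var introduced in this commit.
docker run --gpus all -e SFAST=true livepeer/ai-runner:latest
```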
10 changes: 10 additions & 0 deletions runner/app/pipelines/image_to_image.py
@@ -7,6 +7,7 @@
 import PIL
 from typing import List
 import logging
+import os

 from PIL import ImageFile

@@ -34,6 +35,15 @@ def __init__(self, model_id: str):
         self.ldm = AutoPipelineForImage2Image.from_pretrained(model_id, **kwargs)
         self.ldm.to(get_torch_device())

+        if os.environ.get("SFAST"):
+            logger.info(
+                "ImageToImagePipeline will be dynamically compiled with stable-fast for %s",
+                model_id,
+            )
+            from app.pipelines.sfast import compile_model
+
+            self.ldm = compile_model(self.ldm)
+
     def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]:
         seed = kwargs.pop("seed", None)
         if seed is not None:
10 changes: 10 additions & 0 deletions runner/app/pipelines/image_to_video.py
@@ -7,6 +7,7 @@
 import PIL
 from typing import List
 import logging
+import os

 from PIL import ImageFile

@@ -34,6 +35,15 @@ def __init__(self, model_id: str):
         self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
         self.ldm.to(get_torch_device())

+        if os.environ.get("SFAST"):
+            logger.info(
+                "ImageToVideoPipeline will be dynamically compiled with stable-fast for %s",
+                model_id,
+            )
+            from app.pipelines.sfast import compile_model
+
+            self.ldm = compile_model(self.ldm)
+
     def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
         if "decode_chunk_size" not in kwargs:
             kwargs["decode_chunk_size"] = 8
29 changes: 29 additions & 0 deletions runner/app/pipelines/sfast.py
@@ -0,0 +1,29 @@
+from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def compile_model(model):
+    config = CompilationConfig.Default()
+
+    # xformers and Triton are suggested for achieving best performance.
+    # It might be slow for Triton to generate, compile and fine-tune kernels.
+    try:
+        import xformers
+
+        config.enable_xformers = True
+    except ImportError:
+        logger.info("xformers not installed, skip")
+    # NOTE:
+    # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
+    # Disable Triton if you encounter this problem.
+    try:
+        import triton
+
+        config.enable_triton = True
+    except ImportError:
+        logger.info("Triton not installed, skip")
+
+    model = compile(model, config)
+    return model
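Because compilation is deferred to the first inference call (per the README note above), a deployment may want to warm a pipeline once at startup. A minimal sketch, assuming a CUDA device and the dependencies pinned in this commit; the model id is an arbitrary example, not one used by the runner:

```
# Warm-up sketch (not part of this commit): trigger stable-fast's dynamic
# compilation once so later user-facing requests hit the compiled model.
from diffusers import AutoPipelineForText2Image

from app.pipelines.sfast import compile_model

ldm = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")  # example model id
ldm.to("cuda")
ldm = compile_model(ldm)

# The first call compiles and is slow; subsequent calls reuse the compiled graph.
ldm(prompt="warm-up", num_inference_steps=1)
```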
10 changes: 10 additions & 0 deletions runner/app/pipelines/text_to_image.py
@@ -7,6 +7,7 @@
 import PIL
 from typing import List
 import logging
+import os

 logger = logging.getLogger(__name__)

@@ -30,6 +31,15 @@ def __init__(self, model_id: str):
         self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs)
         self.ldm.to(get_torch_device())

+        if os.environ.get("SFAST"):
+            logger.info(
+                "TextToImagePipeline will be dynamically compiled with stable-fast for %s",
+                model_id,
+            )
+            from app.pipelines.sfast import compile_model
+
+            self.ldm = compile_model(self.ldm)
+
     def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
         seed = kwargs.pop("seed", None)
         if seed is not None:
4 changes: 3 additions & 1 deletion runner/requirements.txt
@@ -6,4 +6,6 @@ pydantic
 Pillow
 python-multipart
 uvicorn
-huggingface_hub
+huggingface_hub
+xformers==0.0.23
+triton>=2.1.0
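With these pins in the image, both optional imports probed in `sfast.py` above should resolve, so `compile_model` enables xformers and Triton. A quick local check (a sketch, not part of the commit):

```
# Sanity-check sketch: confirm the optional backends probed by
# app/pipelines/sfast.py are importable at the pinned versions.
import triton
import xformers

print("xformers", xformers.__version__)  # expect 0.0.23
print("triton", triton.__version__)      # expect >= 2.1.0
```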
