Skip to content

Commit

Permalink
Add suppport for i2vgen-xl im2vid
Browse files Browse the repository at this point in the history
  • Loading branch information
stronk-dev committed Apr 12, 2024
1 parent 3c297ca commit 40f757c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 24 deletions.
34 changes: 29 additions & 5 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir

from diffusers import StableVideoDiffusionPipeline
from diffusers import StableVideoDiffusionPipeline, I2VGenXLPipeline
from huggingface_hub import file_download
import torch
import PIL
Expand All @@ -15,6 +15,8 @@

logger = logging.getLogger(__name__)

I2VGEN_LIGHTNING_MODEL_ID = "ali-vilab/i2vgen-xl"
SVD_LIGHTNING_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"

class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
Expand All @@ -37,7 +39,12 @@ def __init__(self, model_id: str):
kwargs["variant"] = "fp16"

self.model_id = model_id
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)

if I2VGEN_LIGHTNING_MODEL_ID in model_id:
self.ldm = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
else:
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
self.ldm.enable_vae_slicing()
self.ldm.to(get_torch_device())

if os.environ.get("SFAST"):
Expand All @@ -50,9 +57,6 @@ def __init__(self, model_id: str):
self.ldm = compile_model(self.ldm)

def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
if "decode_chunk_size" not in kwargs:
kwargs["decode_chunk_size"] = 4

seed = kwargs.pop("seed", None)
if seed is not None:
if isinstance(seed, int):
Expand All @@ -64,6 +68,26 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

if SVD_LIGHTNING_MODEL_ID in self.model_id:
if "decode_chunk_size" not in kwargs:
kwargs["decode_chunk_size"] = 4
if "prompt" in kwargs:
del kwargs["prompt"]
elif I2VGEN_LIGHTNING_MODEL_ID in self.model_id:
kwargs["num_frames"] = 18
kwargs["num_inference_steps"] = 50
if "fps" in kwargs:
del kwargs["fps"]
if "motion_bucket_id" in kwargs:
del kwargs["motion_bucket_id"]
if "noise_aug_strength" in kwargs:
del kwargs["noise_aug_strength"]
prompt = ""
if "prompt" in kwargs:
prompt = kwargs["prompt"]
del kwargs["prompt"]
return self.ldm(prompt, image, **kwargs).frames

return self.ldm(image, **kwargs).frames

def __str__(self) -> str:
Expand Down
2 changes: 2 additions & 0 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
async def image_to_video(
image: Annotated[UploadFile, File()],
prompt: Annotated[str, Form()] = "",
model_id: Annotated[str, Form()] = "",
height: Annotated[int, Form()] = 576,
width: Annotated[int, Form()] = 1024,
Expand Down Expand Up @@ -74,6 +75,7 @@ async def image_to_video(
batch_frames = pipeline(
image=Image.open(image.file).convert("RGB"),
height=height,
prompt=prompt,
width=width,
fps=fps,
motion_bucket_id=motion_bucket_id,
Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ else
# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download video models
huggingface-cli download ali-vilab/i2vgen-xl --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand Down
5 changes: 5 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@
"title": "Model Id",
"default": ""
},
"prompt": {
"type": "string",
"title": "Prompt",
"default": ""
},
"height": {
"type": "integer",
"title": "Height",
Expand Down
39 changes: 20 additions & 19 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 40f757c

Please sign in to comment.