From 40f757c1c676f8c9b27aaa46223a75ecef713fee Mon Sep 17 00:00:00 2001 From: Marco van Dijk Date: Sat, 13 Apr 2024 00:32:30 +0200 Subject: [PATCH] Add suppport for i2vgen-xl im2vid --- runner/app/pipelines/image_to_video.py | 34 ++++++++++++++++++---- runner/app/routes/image_to_video.py | 2 ++ runner/dl_checkpoints.sh | 3 ++ runner/openapi.json | 5 ++++ worker/runner.gen.go | 39 +++++++++++++------------- 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 5967a672..8b4afbdc 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -1,7 +1,7 @@ from app.pipelines.base import Pipeline from app.pipelines.util import get_torch_device, get_model_dir -from diffusers import StableVideoDiffusionPipeline +from diffusers import StableVideoDiffusionPipeline, I2VGenXLPipeline from huggingface_hub import file_download import torch import PIL @@ -15,6 +15,8 @@ logger = logging.getLogger(__name__) +I2VGEN_LIGHTNING_MODEL_ID = "ali-vilab/i2vgen-xl" +SVD_LIGHTNING_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" class ImageToVideoPipeline(Pipeline): def __init__(self, model_id: str): @@ -37,7 +39,12 @@ def __init__(self, model_id: str): kwargs["variant"] = "fp16" self.model_id = model_id - self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs) + + if I2VGEN_LIGHTNING_MODEL_ID in model_id: + self.ldm = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16") + else: + self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs) + self.ldm.enable_vae_slicing() self.ldm.to(get_torch_device()) if os.environ.get("SFAST"): @@ -50,9 +57,6 @@ def __init__(self, model_id: str): self.ldm = compile_model(self.ldm) def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: - if "decode_chunk_size" not in kwargs: - kwargs["decode_chunk_size"] = 4 - seed = kwargs.pop("seed", None) if seed is not None: if isinstance(seed, int): @@ -64,6 +68,26 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: torch.Generator(get_torch_device()).manual_seed(s) for s in seed ] + if SVD_LIGHTNING_MODEL_ID in self.model_id: + if "decode_chunk_size" not in kwargs: + kwargs["decode_chunk_size"] = 4 + if "prompt" in kwargs: + del kwargs["prompt"] + elif I2VGEN_LIGHTNING_MODEL_ID in self.model_id: + kwargs["num_frames"] = 18 + kwargs["num_inference_steps"] = 50 + if "fps" in kwargs: + del kwargs["fps"] + if "motion_bucket_id" in kwargs: + del kwargs["motion_bucket_id"] + if "noise_aug_strength" in kwargs: + del kwargs["noise_aug_strength"] + prompt = "" + if "prompt" in kwargs: + prompt = kwargs["prompt"] + del kwargs["prompt"] + return self.ldm(prompt, image, **kwargs).frames + return self.ldm(image, **kwargs).frames def __str__(self) -> str: diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 66544d52..181b0de5 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -32,6 +32,7 @@ ) async def image_to_video( image: Annotated[UploadFile, File()], + prompt: Annotated[str, Form()] = "", model_id: Annotated[str, Form()] = "", height: Annotated[int, Form()] = 576, width: Annotated[int, Form()] = 1024, @@ -74,6 +75,7 @@ async def image_to_video( batch_frames = pipeline( image=Image.open(image.file).convert("RGB"), height=height, + prompt=prompt, width=width, fps=fps, motion_bucket_id=motion_bucket_id, diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 33751bad..76254236 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -80,6 +80,9 @@ else # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + # Download video models + huggingface-cli download ali-vilab/i2vgen-xl --include "*.fp16.safetensors" "*.json" --cache-dir models + # Download image-to-video models (token-gated). printf "\nDownloading token-gated models...\n" check_hf_auth diff --git a/runner/openapi.json b/runner/openapi.json index 57dcc4a0..db467aeb 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -289,6 +289,11 @@ "title": "Model Id", "default": "" }, + "prompt": { + "type": "string", + "title": "Prompt", + "default": "" + }, "height": { "type": "integer", "title": "Height", diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 9e2e74b3..6244b155 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -51,6 +51,7 @@ type BodyImageToVideoImageToVideoPost struct { ModelId *string `json:"model_id,omitempty"` MotionBucketId *int `json:"motion_bucket_id,omitempty"` NoiseAugStrength *float32 `json:"noise_aug_strength,omitempty"` + Prompt *string `json:"prompt,omitempty"` Seed *int `json:"seed,omitempty"` Width *int `json:"width,omitempty"` } @@ -1077,25 +1078,25 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xXS2/jNhD+KwTbozd23E1T+Jb0tUabbhC720MQGIw0lrkrkSw5TNcI/N8LkrZEvSqn", - "yKZAkZNew5lvZr556JEmslBSgEBDZ4/UJBsomL+9uJ7/qLXU7l5pqUAjB/+lMJm7IMcc6IxemYyOKG6V", - "ezCoucjobjeiGv60XENKZ7f+yN2oPFLqLs/J+4+QIN2N6KVMtytesAxWKPc3jUclDbZhZZanTCSwMglz", - "Vh5pCmtmc6Sz85OzyvjPezmy8HIlBGGLe9AOgrfiFKylLhjSGb3nguktrZTMvUjL7REtZAr5iqc1+zQ6", - "eeUEyDztOiwgY8gfYKW0LBT26vhtL0eug1yXKluEaJmVAt2l8DTSZwviPTLkGnRLKxcIWQhNpedwth+C", - "AUhjyYV77lJqUIPIcFODNzn5rgK4OEi0stUgmjqgCTmMOHcsrwYp+cBTkM3HbkqulanzsILzkzKdsdgA", - "zzb1RJ2df1udexe+dx39z2hbSORSrO5t8gmwqeR0eh5rcZLk0kvWtEV+CMkNrJjNVj3EmEwj6jphcmEz", - "0s+RJ1DxL542zJ1Opm8rc3/47+2TDRoOsK+fQh3se7dcXvd04hSQ8dzdfa1hTWf0q3HVz8f7Zj4uu20T", - "5f54BLOy1QPkA8t5ylwSByFxhMIMYWvq21VYfgiaSiBMa7b1PsRomwq6cAPLcfP9BpJPbbwGGdp6ldL3", - "v9C49XiBrglX1WRloMO+L7obMEoKA20EoUsfHbErSDmL4xQad1ecWow0ca7rsDpwB0vtiB1bS1bnsdzv", - "Oh/cE6yX8RYipAFIB8IlfMal9I5cM81C8L7UVlB15iN68esa8PQ1oOy9T2y2ezARYdq86CDPYCvLZVKr", - "Sia279d0dvvY8vGxBfEuKtBfZeLNtEp01FqlwZieAR1eVKIeM1m6t0NF5fwIpvaSUaSOaJ8f3HTqb19r", - "zYpG+3piH2vEpNyQguKBvrY3H7tUw9tyyDMysZrjduGgBOxulFwC06DL3yB36D68KpVsEBXdOR1crGWo", - "I5Nornx+Z/RCEKZUzkPCCUqirSAXc6K4gpyL4M+BF/wBFIB232+sEN7QA2gTdE1OTk8mLiBSgWCK0xn9", - "xr8aUcVw42GPN370+EYHvh5darzxeVpOJupCFuLhT00nE3dJpEAQ/lQEevzROPOHf8GhNMazzwemHpCF", - "TRIwZm1zUqbEp8AWhVtNS4ju5dh3qjco35Sr7GGvrrvlK3tf4DTwAQy6HavhV2Fz5IppHLud+E3KkB3v", - "2rF/DLs6J1Fb2H3BiNfn9rExH9G3z5n1ck/ssH/JUnITUuLtTqfPare1MrYRVCKkXCvPXsr9uUDQguVk", - "AfoBNKl270Pf8TMk7ji3d7u7uCZ8islShmncqA3/tzBYG74LvlRt9P/PvHBt1Hv/a238n2sjMNzXBsJn", - "PGJsRGvhP1bGv3e+vXi+DofXAnjeAnAci2fDbvd3AAAA//92bsyPxhcAAA==", + "H4sIAAAAAAAC/+xX227jNhN+FYL/f+nEjrtpCt8lPa3RphvE7vYiCAxGGsvclUiWHKZrBH73gqQtUafK", + "22ZToMiVTsOZb2a+OeiJJrJQUoBAQ2dP1CQbKJi/vbyZf6+11O5eaalAIwf/pTCZuyDHHOiMXpuMjihu", + "lXswqLnI6G43ohp+t1xDSmd3/sj9qDxS6i7PyYcPkCDdjeiVTLcrXrAMVij3N41HJQ22YWWWp0wksDIJ", + "c1aeaAprZnOks4vT88r4j3s5svByJQRhiwfQDoK34hSspS4Y0hl94ILpLa2UzL1Iy+0RLWQK+YqnNfs0", + "OnntBMg87TosIGPIH2GltCwU9ur4ZS9HboJclypbhGiZlQLdpfAs0mcL4j0y5AZ0SysXCFkITaXncLYf", + "ggFIY8mFe+5SalCDyHBTgzc5/aYCuDhItLLVIJo6oAk5jDh3LK8GKfnIU5DNx25KrpWp87CC84MynbHY", + "AM829USdX3xdnXsbvncd/ddoW0jkUqwebPIRsKnkbHoRa3GS5MpL1rRFfgjJDayYzVY9xJhMI+o6YXJp", + "M9LPkZi2nc49A4f/4GkD59lk+qYy8Zv/3j7Z4O8Abfu510Hbt8vlTU8LTwEZz93d/zWs6Yz+b1wNgvF+", + "CozLNt1EuT8ewaxs9QB5z3KeMpf9QUgcoTBD2Jr6dhWW74KmEgjTmm29DzHapoIu3MBy3Hy7geRjG69B", + "hrZe3vTdTzTuWV6gazRWxVwZ6LDvq/UWjJLCQBtBaO9HR+waUs7iOIWO3xWnFiNNnOs6rA7cwVI7YsfW", + "ktV5LPerzgcXDOtlvIUIaQDSgXAJn3ApvSM3TLMQvC+1TlQt/Ygm/ro//IPe+5nNdg8mIkybFx3kGWxl", + "uUxqVcnE9t2azu6eWj4+tSDeRwX6s0y8mVaJjlo7OBjTM9nDi0rUYyZL93aoqJwfwdReMorUEe3zvZtO", + "/e1rrVnRaF+f2ccaMSlXq6B4oK/tzccu1fC2HPKMTKzmuF04KAG7GyVXwDTo8v/JHXoIr0olG0RFd04H", + "F2sZ6sgkmiuf3xm9FIQplfOQcIKSaCvI5ZworiDnIvhz4AV/BAWg3fdbK4Q39AjaBF2T07PTiQuIVCCY", + "4nRGv/KvRlQx3HjY440fPb7Rga9HlxpvfJ6Wk4m6kIV4+FPTycRdEikQhD8VgR5/MM784SdyKI3x7POB", + "qQdkYZMEjFnbnJQp8SmwReF22hKiezn2neoE5Um5Ax8W8rpbvrL3BU4DH8Cg27EafhU2R66YxrFbpk9S", + "hux414791djVOYnawu4LRrw+t4+N+Yi+ec6sl3tih/0rlpLbkBJvdzp9VrutlbGNoBIh5Vp5/lLuzwWC", + "FiwnC9CPoEm1ex/6jp8hcce5u9/dxzXhU0yWMkzjRm34v4XB2vBd8KVqo/9/5oVro977X2vjv1wbgeG+", + "NhA+4RFjI1oL/7Iy/r7z7cXzdTi8FsDzFoDjWDwbdrs/AwAA//8s2LIu/xcAAA==", } // GetSwagger returns the content of the embedded swagger specification file