Skip to content

Commit

Permalink
Update main.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
WentianZhang-ML authored Aug 30, 2024
1 parent 20a8f92 commit 06e2118
Showing 1 changed file with 137 additions and 54 deletions.
191 changes: 137 additions & 54 deletions main.py
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,27 @@
import argparse
import os

import torch

from tgate import TgateSDXLLoader,TgatePixArtLoader,TgateSDLoader,TgateSDDeepCacheLoader,TgateSDXLDeepCacheLoader
from diffusers import PixArtAlphaPipeline,StableDiffusionXLPipeline,StableDiffusionPipeline
from tgate import TgateSDXLLoader, TgateSDXLDeepCacheLoader, TgatePixArtAlphaLoader, TgatePixArtSigmaLoader, TgateSVDLoader
from diffusers import StableDiffusionXLPipeline, PixArtAlphaPipeline, PixArtSigmaPipeline, StableVideoDiffusionPipeline
from diffusers import UNet2DConditionModel, LCMScheduler
from diffusers import DPMSolverMultistepScheduler
from diffusers.utils import load_image, export_to_video

def parse_args():
parser = argparse.ArgumentParser(description="Simple example of TGATE.")
parser = argparse.ArgumentParser(description="Simple example of TGATE V2.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="the input prompts",
)
parser.add_argument(
"--image",
type=str,
default=None,
help="the dir of input image to generate video",
)
parser.add_argument(
"--saved_path",
type=str,
Expand All @@ -27,14 +33,32 @@ def parse_args():
"--model",
type=str,
default='pixart',
help="[pixart,sd_xl,sd_2.1,sd_1.5,lcm_sdxl,lcm_pixart]",
help="[pixart_alpha,pixart_sigma,sdxl,lcm_sdxl,lcm_pixart_alpha,svd]",
)
parser.add_argument(
"--gate_step",
type=int,
default=10,
help="When re-using the cross-attention",
)
parser.add_argument(
'--sp_interval',
type=int,
default=5,
help="The time-step interval to cache self attention before gate_step (Semantics-Planning Phase).",
)
parser.add_argument(
'--fi_interval',
type=int,
default=1,
help="The time-step interval to cache self attention after gate_step (Fidelity-Improving Phase).",
)
parser.add_argument(
'--warm_up',
type=int,
default=2,
help="The time step to warm up the model inference",
)
parser.add_argument(
"--inference_step",
type=int,
Expand All @@ -55,60 +79,37 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
os.makedirs(args.saved_path, exist_ok=True)
saved_path = os.path.join(args.saved_path, 'test.png')
if args.model in ['sd_2.1', 'sd_1.5']:
if args.model == 'sd_1.5':
repo_id = "runwayml/stable-diffusion-v1-5"
elif args.model == 'sd_2.1':
repo_id = "stabilityai/stable-diffusion-2-1"

pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
if args.deepcache:
pipe = TgateSDDeepCacheLoader(pipe,cache_interval=3,cache_branch_id=0)
else:
pipe = TgateSDLoader(pipe)
pipe = pipe.to("cuda")

image = pipe.tgate(args.prompt,
num_inference_steps=args.inference_step,
guidance_scale=7.5,
gate_step=args.gate_step,
).images[0]

if args.prompt:
saved_path = os.path.join(args.saved_path, 'test.png')
elif args.image:
saved_path = os.path.join(args.saved_path, 'test.mp4')

elif args.model == 'sd_xl':
pipeline_text2image = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
if args.model == 'sdxl':
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
)

if args.deepcache:
pipeline_text2image = TgateSDXLDeepCacheLoader(pipeline_text2image,cache_interval=3,cache_branch_id=0)
pipe = TgateSDXLDeepCacheLoader(
pipe,
cache_interval=3,
cache_branch_id=0
)
else:
pipeline_text2image = TgateSDXLLoader(pipeline_text2image)
pipeline_text2image.scheduler = DPMSolverMultistepScheduler.from_config(pipeline_text2image.scheduler.config)
pipeline_text2image = pipeline_text2image.to("cuda")

image = pipeline_text2image.tgate(prompt=args.prompt,
gate_step=args.gate_step,
num_inference_steps=args.inference_step).images[0]
elif args.model == 'pixart':
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe = TgatePixArtLoader(pipe).to("cuda")
image = pipe.tgate(args.prompt,
gate_step=args.gate_step,
num_inference_steps=args.inference_step).images[0]

pipe = TgateSDXLLoader(pipe)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

elif args.model == 'lcm_pixart':
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-LCM-XL-2-1024-MS", torch_dtype=torch.float16)
pipe = TgatePixArtLoader(pipe,lcm=True).to("cuda")
image = pipe.tgate(
args.prompt,
prompt=args.prompt,
gate_step=args.gate_step,
sp_interval=args.sp_interval if not args.deepcache else 1,
fi_interval=args.fi_interval,
warm_up=args.warm_up if not args.deepcache else 0,
num_inference_steps=args.inference_step,
guidance_scale=0.).images[0]

).images[0]
image.save(saved_path)

elif args.model == 'lcm_sdxl':
unet = UNet2DConditionModel.from_pretrained(
Expand All @@ -117,17 +118,99 @@ def parse_args():
variant="fp16",
)
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16, variant="fp16",
"stabilityai/stable-diffusion-xl-base-1.0",
unet=unet,
torch_dtype=torch.float16,
variant="fp16",
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe = TgateSDXLLoader(pipe,lcm=True).to("cuda")
pipe = TgateSDXLLoader(pipe).to("cuda")

image = pipe.tgate(
prompt=args.prompt,
gate_step=args.gate_step,
sp_interval=1,
fi_interval=args.fi_interval,
warm_up=0,
num_inference_steps=args.inference_step,
lcm=True,
).images[0]
image.save(saved_path)

elif args.model == 'pixart_alpha':
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
torch_dtype=torch.float16,
)
pipe = TgatePixArtAlphaLoader(pipe).to("cuda")

image = pipe.tgate(
prompt=args.prompt,
gate_step=args.gate_step,
sp_interval=args.sp_interval,
fi_interval=args.fi_interval,
warm_up=args.warm_up,
num_inference_steps=args.inference_step,
).images[0]
image.save(saved_path)

elif args.model == 'lcm_pixart':
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-LCM-XL-2-1024-MS",
torch_dtype=torch.float16,
)
pipe = TgatePixArtAlphaLoader(pipe).to("cuda")

image = pipe.tgate(
args.prompt,
gate_step=args.gate_step,
sp_interval=1,
fi_interval=args.fi_interval,
warm_up=0,
num_inference_steps=args.inference_step,
lcm=True,
guidance_scale=0.,
).images[0]
image.save(saved_path)

elif args.model == 'pixart_sigma':
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
torch_dtype=torch.float16,
)
pipe = TgatePixArtSigmaLoader(pipe).to("cuda")

image = pipe.tgate(
prompt=args.prompt,
gate_step=args.gate_step,
sp_interval=args.sp_interval,
fi_interval=args.fi_interval,
warm_up=args.warm_up,
num_inference_steps=args.inference_step,
).images[0]
image.save(saved_path)

elif args.model == 'svd':
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt",
torch_dtype=torch.float16,
variant="fp16",
)
pipe = TgateSVDLoader(pipe).to("cuda")

image = load_image(args.image)

frames = pipe.tgate(
image,
gate_step=args.gate_step,
num_inference_steps=args.inference_step,
warm_up=args.warm_up,
sp_interval=args.sp_interval,
fi_interval=args.fi_interval,
num_frames=25,
decode_chunk_size=8,
).frames[0]
export_to_video(frames, saved_path, fps=7)

else:
raise Exception('Please sepcify the model name!')
image.save(saved_path)

0 comments on commit 06e2118

Please sign in to comment.