Merge pull request #135 from nateraw/remove-autocast
Remove autocast from pipeline
nateraw authored Jan 6, 2023
2 parents 8c77e29 + c2f0423 commit def2c57
Showing 2 changed files with 51 additions and 42 deletions.
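With autocast gone, the pipeline simply runs in whatever dtype its weights were loaded in, so half precision is chosen once at load time rather than wrapped around every call. A minimal usage sketch after this change, assuming the package's exported StableDiffusionWalkPipeline and an illustrative checkpoint id:

    import torch
    from stable_diffusion_videos import StableDiffusionWalkPipeline

    # Load weights directly in fp16 instead of wrapping calls in
    # torch.autocast("cuda"). The checkpoint id is an assumed example.
    pipeline = StableDiffusionWalkPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
    ).to("cuda")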
requirements.txt (4 changes: 2 additions & 2 deletions)
@@ -1,8 +1,8 @@
 transformers>=4.21.0
-diffusers==0.9.0
+diffusers==0.11.1
 scipy
 fire
 gradio
 librosa
 av<10.0.0
-realesrgan==0.2.5.0
+realesrgan==0.2.5.0
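The diffusers pin moves from 0.9.0 to 0.11.1. The height/width defaulting added below reads self.vae_scale_factor, an attribute newer diffusers releases define on their Stable Diffusion pipelines, so the version bump and the code change travel together. A quick sanity check, offered as an illustrative snippet:

    import diffusers

    # The pipeline code below assumes a diffusers release that defines
    # vae_scale_factor on its Stable Diffusion pipelines (0.11.x does).
    assert diffusers.__version__ == "0.11.1", diffusers.__version__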
stable_diffusion_videos/stable_diffusion_pipeline.py (89 changes: 49 additions & 40 deletions)
@@ -299,8 +299,8 @@ def disable_attention_slicing(self):
     def __call__(
         self,
         prompt: Optional[Union[str, List[str]]] = None,
-        height: int = 512,
-        width: int = 512,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -371,6 +371,9 @@ def __call__(
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
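The new fallback computes the same 512 x 512 as the old hard-coded defaults for Stable Diffusion v1 checkpoints, where unet.config.sample_size is 64 and vae_scale_factor is 8; the defaults now just track whatever model is loaded. A sketch of the arithmetic, with config values assumed rather than read from this diff:

    # Typical Stable Diffusion v1 values (assumed for illustration):
    sample_size = 64      # latent resolution, unet.config.sample_size
    vae_scale_factor = 8  # 2 ** (len(vae.config.block_out_channels) - 1)

    height = None
    height = height or sample_size * vae_scale_factor
    print(height)  # 512, matching the old hard-coded default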
@@ -549,9 +552,9 @@ def __call__(
     def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
         embeds_a = self.embed_text(prompt_a)
         embeds_b = self.embed_text(prompt_b)
-
-        latents_a = self.init_noise(seed_a, noise_shape)
-        latents_b = self.init_noise(seed_b, noise_shape)
+        latents_dtype = embeds_a.dtype
+        latents_a = self.init_noise(seed_a, noise_shape, latents_dtype)
+        latents_b = self.init_noise(seed_b, noise_shape, latents_dtype)

         batch_idx = 0
         embeds_batch, noise_batch = None, None
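Threading embeds_a.dtype through to init_noise matters once autocast is gone: torch.randn defaults to float32, so a pipeline loaded with torch_dtype=torch.float16 would otherwise feed fp32 latents to fp16 weights. A standalone sketch of the mismatch this avoids, not the pipeline's own code:

    import torch

    embeds = torch.ones(1, 77, 768, dtype=torch.float16)  # stand-in text embeddings

    noise_default = torch.randn(1, 4, 64, 64)              # comes back float32
    noise_matched = torch.randn(1, 4, 64, 64, dtype=embeds.dtype)

    assert noise_default.dtype != embeds.dtype
    assert noise_matched.dtype == embeds.dtype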
@@ -581,15 +584,19 @@ def make_clip_frames(
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         eta: float = 0.0,
-        height: int = 512,
-        width: int = 512,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         upsample: bool = False,
         batch_size: int = 1,
         image_file_ext: str = ".png",
         T: np.ndarray = None,
         skip: int = 0,
         negative_prompt: str = None,
     ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
         save_path = Path(save_path)
         save_path.mkdir(parents=True, exist_ok=True)

@@ -614,24 +621,23 @@ def make_clip_frames(
         frame_index = skip
         for _, embeds_batch, noise_batch in batch_generator:
-            with torch.autocast("cuda"):
-                outputs = self(
-                    latents=noise_batch,
-                    text_embeddings=embeds_batch,
-                    height=height,
-                    width=width,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
-                    output_type="pil" if not upsample else "numpy",
-                    negative_prompt=negative_prompt,
-                )["images"]
-
-                for image in outputs:
-                    frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
-                    image = image if not upsample else self.upsampler(image)
-                    image.save(frame_filepath)
-                    frame_index += 1
+            outputs = self(
+                latents=noise_batch,
+                text_embeddings=embeds_batch,
+                height=height,
+                width=width,
+                guidance_scale=guidance_scale,
+                eta=eta,
+                num_inference_steps=num_inference_steps,
+                output_type="pil" if not upsample else "numpy",
+                negative_prompt=negative_prompt,
+            )["images"]
+
+            for image in outputs:
+                frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
+                image = image if not upsample else self.upsampler(image)
+                image.save(frame_filepath)
+                frame_index += 1

     def walk(
         self,
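Dropping the torch.autocast("cuda") context also means CPU and MPS devices take the same code path as CUDA when rendering frames. A hypothetical call after the change; the parameters before num_inference_steps are elided from this diff, so prompt_a, prompt_b, seed_a, seed_b, num_interpolation_steps, and save_path are assumed names here:

    pipeline.make_clip_frames(
        prompt_a="a photo of a forest",
        prompt_b="a photo of a beach",
        seed_a=42,
        seed_b=1337,
        num_interpolation_steps=5,
        save_path="dreams/forest_to_beach",
    )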
@@ -645,8 +651,8 @@ def walk(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         upsample: Optional[bool] = False,
         batch_size: Optional[int] = 1,
         resume: Optional[bool] = False,
@@ -686,9 +692,9 @@ def walk(
             eta (Optional[float], *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            height (Optional[int], *optional*, defaults to 512):
+            height (Optional[int], *optional*, defaults to None):
                 height of the images to generate.
-            width (Optional[int], *optional*, defaults to 512):
+            width (Optional[int], *optional*, defaults to None):
                 width of the images to generate.
             upsample (Optional[bool], *optional*, defaults to False):
                 When True, upsamples images with realesrgan.
@@ -744,6 +750,9 @@ def walk(
         Returns:
             str: The resulting video filepath. This video includes all sub directories' video clips.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

         output_path = Path(output_dir)

@@ -879,19 +888,18 @@ def walk(

     def embed_text(self, text, negative_prompt=None):
         """Helper to embed some text"""
-        with torch.autocast("cuda"):
-            text_input = self.tokenizer(
-                text,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            with torch.no_grad():
-                embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_input = self.tokenizer(
+            text,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        with torch.no_grad():
+            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
         return embed

-    def init_noise(self, seed, noise_shape):
+    def init_noise(self, seed, noise_shape, dtype):
         """Helper to initialize noise"""
         # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
         if self.device.type == "mps":
@@ -905,6 +913,7 @@ def init_noise(self, seed, noise_shape):
             noise_shape,
             device=self.device,
             generator=torch.Generator(device=self.device).manual_seed(seed),
+            dtype=dtype,
         )
         return noise
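With the extra parameter, callers decide the noise dtype instead of inheriting float32. A standalone approximation of the helper under the same mps caveat, with hypothetical names throughout:

    import torch

    def init_noise_sketch(seed, noise_shape, dtype, device):
        # randn is unavailable on mps, so seed and sample on cpu, then move over.
        if device.type == "mps":
            generator = torch.Generator(device="cpu").manual_seed(seed)
            return torch.randn(noise_shape, generator=generator, dtype=dtype).to(device)
        generator = torch.Generator(device=device).manual_seed(seed)
        return torch.randn(noise_shape, device=device, generator=generator, dtype=dtype)

    noise = init_noise_sketch(42, (1, 4, 64, 64), torch.float32, torch.device("cpu"))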

