Commit e0a98a6

#4630: fixes for some tests and demo in stable diffusion

boris-drazic committed Jan 11, 2024
1 parent dc74432 commit e0a98a6

Showing 5 changed files with 102 additions and 89 deletions.
126 changes: 78 additions & 48 deletions models/experimental/stable_diffusion/demo/demo.py
@@ -9,15 +9,30 @@
 from tqdm.auto import tqdm
 from loguru import logger
 from transformers import CLIPTextModel, CLIPTokenizer
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, HeunDiscreteScheduler, DPMSolverMultistepScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    HeunDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+)
 from diffusers import LMSDiscreteScheduler
 from tqdm.auto import tqdm

-from models.utility_functions import torch_to_tt_tensor, torch_to_tt_tensor_rm, tt_to_torch_tensor, comp_pcc, comp_allclose_and_pcc, Profiler, \
-    enable_persistent_kernel_cache, disable_persistent_kernel_cache
+from models.utility_functions import (
+    torch_to_tt_tensor,
+    torch_to_tt_tensor_rm,
+    tt_to_torch_tensor,
+    comp_pcc,
+    comp_allclose_and_pcc,
+    Profiler,
+    enable_persistent_kernel_cache,
+    disable_persistent_kernel_cache,
+)

 import tt_lib as ttl
 from models.experimental.stable_diffusion.tt.unet_2d_condition import UNet2DConditionModel as tt_unet_condition
+from models.experimental.stable_diffusion.tt.experimental_ops import UseDeviceConv


 def constant_prop_time_embeddings(timesteps, sample, time_proj):
@@ -43,45 +58,49 @@ def save_image_and_latents(latents, iter, vae, pre_fix="", pre_fix2=""):

     torch.save(_latents, f"{pre_fix}{pre_fix2}latents_{iter}.pt")

-def guide(noise_pred, guidance_scale, t): # will return latents
+
+def guide(noise_pred, guidance_scale, t):  # will return latents
     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
     return noise_pred


 def latent_expansion(latents, scheduler, t):
     latent_model_input = torch.cat([latents] * 2)
     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
     return latent_model_input


 def make_tt_unet(state_dict):
-    tt_unet = tt_unet_condition(sample_size = 64,
-                        in_channels = 4,
-                        out_channels = 4,
-                        center_input_sample = False,
-                        flip_sin_to_cos = True,
-                        freq_shift = 0,
-                        down_block_types = ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'],
-                        mid_block_type = 'UNetMidBlock2DCrossAttn',
-                        up_block_types = ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'],
-                        only_cross_attention = False,
-                        block_out_channels = [320, 640, 1280, 1280],
-                        layers_per_block = 2,
-                        downsample_padding = 1,
-                        mid_block_scale_factor = 1,
-                        act_fn = 'silu',
-                        norm_num_groups = 32,
-                        norm_eps = 1e-05,
-                        cross_attention_dim = 768,
-                        attention_head_dim = 8,
-                        dual_cross_attention = False,
-                        use_linear_projection = False,
-                        class_embed_type = None,
-                        num_class_embeds = None,
-                        upcast_attention = False,
-                        resnet_time_scale_shift = 'default',
-                        state_dict=state_dict,
-                        base_address="")
+    tt_unet = tt_unet_condition(
+        sample_size=64,
+        in_channels=4,
+        out_channels=4,
+        center_input_sample=False,
+        flip_sin_to_cos=True,
+        freq_shift=0,
+        down_block_types=["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
+        mid_block_type="UNetMidBlock2DCrossAttn",
+        up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+        only_cross_attention=False,
+        block_out_channels=[320, 640, 1280, 1280],
+        layers_per_block=2,
+        downsample_padding=1,
+        mid_block_scale_factor=1,
+        act_fn="silu",
+        norm_num_groups=32,
+        norm_eps=1e-05,
+        cross_attention_dim=768,
+        attention_head_dim=8,
+        dual_cross_attention=False,
+        use_linear_projection=False,
+        class_embed_type=None,
+        num_class_embeds=None,
+        upcast_attention=False,
+        resnet_time_scale_shift="default",
+        state_dict=state_dict,
+        base_address="",
+    )
     return tt_unet


@@ -103,11 +122,15 @@ def demo():
     unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

     # 4. load the K-LMS scheduler with some fitting parameters.
-    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-    tt_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-    #scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-    #scheduler = HeunDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-    #scheduler = DPMSolverMultistepScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+    scheduler = LMSDiscreteScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+    )
+    tt_scheduler = LMSDiscreteScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+    )
+    # scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+    # scheduler = HeunDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+    # scheduler = DPMSolverMultistepScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

     torch_device = "cpu"
     vae.to(torch_device)
@@ -121,24 +144,28 @@ def demo():
     experiment_name = "mountain_fallback_nolatentupdate"
     # prompt = ["a photo of an astronaut riding a horse on mars"]
    # prompt = ["car"]
-    prompt = ["oil painting frame of Breathtaking mountain range with a clear river running through it, surrounded by tall trees and misty clouds, serene, peaceful, mountain landscape, high detail"]
-
-    height = 256                        # default height of Stable Diffusion
-    width = 256                         # default width of Stable Diffusion
-    num_inference_steps = 2             # Number of denoising steps
-    guidance_scale = 7.5                # Scale for classifier-free guidance
-    generator = torch.manual_seed(174)  # 10233 Seed generator to create the inital latent noise
+    prompt = [
+        "oil painting frame of Breathtaking mountain range with a clear river running through it, surrounded by tall trees and misty clouds, serene, peaceful, mountain landscape, high detail"
+    ]
+
+    height = 256  # default height of Stable Diffusion
+    width = 256  # default width of Stable Diffusion
+    num_inference_steps = 2  # Number of denoising steps
+    guidance_scale = 7.5  # Scale for classifier-free guidance
+    generator = torch.manual_seed(174)  # 10233 Seed generator to create the inital latent noise
     batch_size = len(prompt)

     ## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model.
     # Tokenizer and Text Encoder
-    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = tokenizer(
+        prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+    )
     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
     max_length = text_input.input_ids.shape[-1]
     uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
     uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

-    #For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings),
+    # For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings),
     # and another with the unconditional embeddings (uncond_embeddings).
     # In practice, we can concatenate both into a single batch to avoid doing two forward passes.
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
@@ -204,6 +231,9 @@ def demo():
         # perform guidance
         noise_pred = guide(noise_pred, guidance_scale, t)
         # compute the previous noisy sample x_t -> x_t-1
+        if UseDeviceConv.READY:
+            # force unpad noise_pred
+            noise_pred = noise_pred[:, :4, :, :]
         tt_latents = tt_scheduler.step(noise_pred, t, tt_latents).prev_sample
         save_image_and_latents(tt_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")
         pcc_res[iter] = comp_allclose_and_pcc(latents_dict[iter], tt_latents)
@@ -214,7 +244,6 @@ def demo():
         iter += 1
         enable_persistent_kernel_cache()

-
     latents = last_latents
     for key, val in pcc_res.items():
         logger.info(f"{key}, {val}")
@@ -232,13 +261,14 @@ def demo():

     ttl.device.CloseDevice(device)

-'''
+
+"""
 @article{patil2022stable,
     author = {Patil, Suraj and Cuenca, Pedro and Lambert, Nathan and von Platen, Patrick},
     title = {Stable Diffusion with :firecracker: Diffusers},
     journal = {Hugging Face Blog},
     year = {2022},
     note = {[https://huggingface.co/blog/rlhf](https://huggingface.co/blog/stable_diffusion)},
 }
-'''
+"""
 demo()
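
Note: the unpad guard is the one behavioral change in demo.py; the rest of the hunks are formatting. A minimal sketch of what it does, under the assumption (suggested by the experimental_ops flag) that device-side convolutions pad the channel dimension up to a tile multiple, leaving extra channels the scheduler must not see; the padded shape below is hypothetical:

import torch

LATENT_CHANNELS = 4  # Stable Diffusion v1.x latents carry 4 channels

def unpad_noise_pred(noise_pred: torch.Tensor) -> torch.Tensor:
    # Keep only the real latent channels; anything past index 4 is assumed conv padding.
    return noise_pred[:, :LATENT_CHANNELS, :, :]

# guide() has already collapsed the doubled batch to 1 by this point, so a
# hypothetical padded prediction of shape (1, 32, 32, 32) becomes (1, 4, 32, 32).
padded = torch.randn(1, 32, 32, 32)
assert unpad_noise_pred(padded).shape == (1, 4, 32, 32)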
@@ -250,12 +250,14 @@ def test_perf(device, expected_inference_time, expected_compile_time):
     comments = f"image size: {height}x{width} - v1.4"

     prep_perf_report(
-        "batched_stable_diffusion",
-        BATCH_SIZE,
-        first_iter_time,
-        second_iter_time,
-        comments,
-        cpu_time,
+        model_name="batched_stable_diffusion",
+        batch_size=BATCH_SIZE,
+        inference_and_compile_time=first_iter_time,
+        inference_time=second_iter_time,
+        expected_compile_time=expected_compile_time,
+        expected_inference_time=expected_inference_time,
+        comments=comments,
+        inference_time_cpu=cpu_time,
     )
     logger.info(f"Batched Stable Diffusion {comments} inference time: {second_iter_time}")
     logger.info(f"Batched Stable Diffusion {comments} compile time: {compile_time}")
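
The perf test now calls prep_perf_report with keyword arguments and passes the expected times through instead of dropping them. The keyword names above pin down the call contract; what follows is a hypothetical reconstruction of the signature they imply, not the repository's actual implementation:

from typing import Optional

def prep_perf_report(
    model_name: str,
    batch_size: int,
    inference_and_compile_time: float,
    inference_time: float,
    expected_compile_time: float,
    expected_inference_time: float,
    comments: str = "",
    inference_time_cpu: Optional[float] = None,
) -> None:
    # Sketch only: derive compile time and flag regressions against the expectations.
    compile_time = inference_and_compile_time - inference_time
    if compile_time > expected_compile_time:
        raise AssertionError(f"compile time regression: {compile_time:.2f}s")
    if inference_time > expected_inference_time:
        raise AssertionError(f"inference time regression: {inference_time:.2f}s")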
@@ -223,6 +223,9 @@ def test_batched_stable_diffusion(device):
         noise_pred = guide(noise_pred, guidance_scale, t)

         # compute the previous noisy sample x_t -> x_t-1
+        if UseDeviceConv.READY:
+            # force unpad noise_pred
+            noise_pred = noise_pred[:, :4, :, :]
         tt_latents = tt_scheduler.step(noise_pred, t, tt_latents).prev_sample

         # We need only one iteration
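
For reference, the guide() used here is standard classifier-free guidance: the UNet runs once on a doubled batch (unconditional embeddings concatenated with text embeddings), and the two halves are blended as eps_uncond + s * (eps_text - eps_uncond), where s is the guidance scale (7.5 in the demo). A self-contained sketch of that computation with hypothetical latent shapes, not code taken from this commit:

import torch

def guide(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Split the doubled batch back into unconditional and text-conditioned halves.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional direction by the guidance scale.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Batch of 2 (uncond + text), 4 latent channels, 32x32 latents for a 256x256 image.
pred = torch.randn(2, 4, 32, 32)
assert guide(pred, guidance_scale=7.5).shape == (1, 4, 32, 32)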
@@ -27,40 +27,26 @@


 @pytest.mark.parametrize("index", [1])  # FIXME: failing 0, 2 with L1 error.
-def test_run_cross_attn_down_block_real_input_inference(
-    device, index, model_location_generator
-):
-    pipe = StableDiffusionPipeline.from_pretrained(
-        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
-    )
+def test_run_cross_attn_down_block_real_input_inference(device, index, model_location_generator):
+    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
     unet = pipe.unet
     unet.eval()
     state_dict = unet.state_dict()

     dir_path = model_location_generator("tensor_files", model_subdir="StableDiffusion")
     attr_path = f"{dir_path}/CrossAttnDownBlock2D_inp__attr__block_{index}.pt"
-    attention_mask_path = (
-        f"{dir_path}/CrossAttnDownBlock2D_inp__attention_mask__block_{index}.pt"
-    )
-    cross_attn_kwargs_path = (
-        f"{dir_path}/CrossAttnDownBlock2D_inp__cross_attention_kwargs__block_{index}.pt"
-    )
+    attention_mask_path = f"{dir_path}/CrossAttnDownBlock2D_inp__attention_mask__block_{index}.pt"
+    cross_attn_kwargs_path = f"{dir_path}/CrossAttnDownBlock2D_inp__cross_attention_kwargs__block_{index}.pt"
     emb_path = f"{dir_path}/CrossAttnDownBlock2D_inp__emb__block_{index}.pt"
-    encoder_hidden_states_path = (
-        f"{dir_path}/CrossAttnDownBlock2D_inp__encoder_hidden_states__block_{index}.pt"
-    )
+    encoder_hidden_states_path = f"{dir_path}/CrossAttnDownBlock2D_inp__encoder_hidden_states__block_{index}.pt"
     sample_path = f"{dir_path}/CrossAttnDownBlock2D_inp__sample__block_{index}.pt"

     map_location = torch.device("cpu")
     sample = torch.load(sample_path, map_location=map_location)
     emb = torch.load(emb_path, map_location=map_location)
-    encoder_hidden_states = torch.load(
-        encoder_hidden_states_path, map_location=map_location
-    )
+    encoder_hidden_states = torch.load(encoder_hidden_states_path, map_location=map_location)
     attention_mask = torch.load(attention_mask_path, map_location=map_location)
-    cross_attention_kwargs = torch.load(
-        cross_attn_kwargs_path, map_location=map_location
-    )
+    cross_attention_kwargs = torch.load(cross_attn_kwargs_path, map_location=map_location)

     kwargs = torch.load(attr_path)
     base_address = f"down_block.{index}"
@@ -75,12 +61,8 @@ def test_run_cross_attn_down_block_real_input_inference(
     )

     tt_sample = torch_to_tt_tensor_rm(sample, device, put_on_device=False)
-    tt_emb = torch_to_tt_tensor_rm(
-        emb.unsqueeze(0).unsqueeze(0), device, put_on_device=False
-    )
-    tt_encoder_hidden_states = torch_to_tt_tensor_rm(
-        encoder_hidden_states.unsqueeze(0), device, put_on_device=False
-    )
+    tt_emb = torch_to_tt_tensor_rm(emb.unsqueeze(0).unsqueeze(0), device, put_on_device=False)
+    tt_encoder_hidden_states = torch_to_tt_tensor_rm(encoder_hidden_states.unsqueeze(0), device, put_on_device=False)

     tt_cross_attn_down_block = TtCrossAttnDownBlock2D(
         **kwargs, state_dict=state_dict, base_address=f"down_blocks.{index}"
@@ -109,9 +91,7 @@ def test_run_cross_attn_down_block_real_input_inference(

 def test_run_cross_attn_down_block_inference(device):
     # setup pytorch model
-    pipe = StableDiffusionPipeline.from_pretrained(
-        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
-    )
+    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
     unet = pipe.unet
     unet.eval()
     state_dict = unet.state_dict()
@@ -184,9 +164,7 @@ def test_run_cross_attn_down_block_inference(device):

     tt_sample = torch_to_tt_tensor_rm(sample, device, put_on_device=False)
     tt_emb = torch_to_tt_tensor_rm(emb, device, put_on_device=False)
-    tt_encoder_hidden_states = torch_to_tt_tensor_rm(
-        encoder_hidden_states, device, put_on_device=False
-    )
+    tt_encoder_hidden_states = torch_to_tt_tensor_rm(encoder_hidden_states, device, put_on_device=False)

     tt_output, list_out = tt_cross_attn_down_block(
         tt_sample,
@@ -195,7 +173,7 @@ def test_run_cross_attn_down_block_inference(device):
         attention_mask=attention_mask,
         cross_attention_kwargs=cross_attention_kwargs,
     )
-    ttl.device.Synchronize()
+    ttl.device.Synchronize(device)
     tt_output = tt_to_torch_tensor(tt_output)

     passing = comp_pcc(torch_output, tt_output, pcc=0.95)
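
The same file also moves from the zero-argument ttl.device.Synchronize() to the device-scoped form. A sketch of the intended lifecycle; only Synchronize(device) and CloseDevice(device) appear in this commit, and the CreateDevice call below is an assumption about the rest of the API:

import tt_lib as ttl

device = ttl.device.CreateDevice(0)  # assumed device-creation call; opens device 0
try:
    # ... build tensors and run model ops on `device` ...
    ttl.device.Synchronize(device)  # block until work queued on this device has finished
finally:
    ttl.device.CloseDevice(device)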
@@ -95,7 +95,7 @@ def test_run_basic_transformer_inference(device):
     logger.info(f"PASSED {passing[1]}")


-def test_run_transformer_inference():
+def test_run_transformer_inference(device):
     # setup pytorch model
     pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
     unet = pipe.unet
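
test_run_transformer_inference previously took no parameters, so pytest never injected the device fixture that the other tests in this file receive. A hypothetical conftest.py fixture of the kind this pattern assumes (the repository's actual fixture may differ):

import pytest
import tt_lib as ttl

@pytest.fixture
def device():
    dev = ttl.device.CreateDevice(0)  # assumed device-creation call
    yield dev  # the test body runs with an open device handle
    ttl.device.CloseDevice(dev)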
