#5773: Move SD model to demo folder

Sudharsan-V authored and Maganuru Jayasurya committed May 30, 2024
1 parent e75540b commit f3d075d
Showing 65 changed files with 316 additions and 270 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -44,7 +44,7 @@
| [Mistral-7B-decode](./models/demos/wormhole/mistral7b) | 33rd | 32 | 10.9 t/s/u - 349 t/s | 13.3 t/s/u - 426 t/s | 21 t/s/u |
| [Mamba-2.8B-decode](./models/demos/mamba) | any | 32 | 9.2 t/s/u - 295 t/s | 13.1 t/s/u - 419 t/s | 22 t/s/u |
| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) | any | 8 | 270 | 340 | 400 |
| Stable Diffusion 1.4 512x512 | coming soon | 1 | | | |
| Stable Diffusion 1.4 512x512 (seconds for denoise) | | 1 | 114s | 0.2s | |

[3] - Generating the i'th token in a sequence while the kv_cache is filled with i-1 rows.

32 changes: 32 additions & 0 deletions models/demos/wormhole/stable_diffusion/README.md
@@ -0,0 +1,32 @@
# Stable Diffusion Model

## Introduction
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.

## Details
The entry point to the functional Stable Diffusion model is `UNet2DConditionModel` in `models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_unet_2d_condition_model.py`. The model picks up certain configs and weights from the Hugging Face pretrained model; we use the `CompVis/stable-diffusion-v1-4` checkpoint as our reference.
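
As a rough orientation, here is a minimal sketch of how the demo wires this together, mirroring the demo code in this commit; `device` is assumed to be an already-open ttnn device (for example, the pytest `device` fixture):

```python
# Sketch only: load the reference Hugging Face UNet, convert its weights to ttnn
# parameters, and construct the functional UNet2D model used by the demo.
from diffusers import UNet2DConditionModel
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor
from models.demos.wormhole.stable_diffusion.tt2.ttnn_functional_unet_2d_condition_model import (
    UNet2DConditionModel as UNet2D,
)

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
parameters = preprocess_model_parameters(
    initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
# Batch size 2 covers the classifier-free-guidance pair; 64x64 is the latent
# resolution, and the final argument is the reader-patterns cache used for 512x512 runs.
model = UNet2D(device, parameters, 2, 64, 64, {})
```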

## Inputs
Inputs are provided from `input_data.json` by default. If you wish to change the inputs, provide a different path to `test_demo`. We do not recommend modifying the `input_data.json` file.

## How to Run

To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables.
For more information, refer to the [installation and build guide](https://github.com/tenstorrent/tt-metal/blob/main/INSTALLING.md).

Use `pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` to run the demo.

If you wish to run the demo with a different input, use `pytest --disable-warnings --input-path="<path_to_your_json_file.json>" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo`.

Our second demo is designed to run on the `poloclub/diffusiondb` dataset. Run it with `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb`.

If you wish to run it for `num_prompts` samples and `num_inference_steps` denoising steps, use `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[<num_prompts>-<num_inference_steps>]`.

Note: the ttnn Stable Diffusion demo uses `PNDMScheduler`, which requires `num_inference_steps` to be greater than or equal to 4. [Reference](https://arxiv.org/pdf/2202.09778)
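
For illustration, this is roughly how the demo constructs the scheduler and enforces that constraint (`device` again assumed to be an open ttnn device):

```python
# Sketch of the scheduler setup used by the demo; PNDM's multi-step update needs
# at least 4 denoising steps, hence the assertion.
from models.demos.wormhole.stable_diffusion.sd_pndm_scheduler import TtPNDMScheduler

num_inference_steps = 4  # minimum supported value
assert num_inference_steps >= 4, "PNDMScheduler only supports num_inference_steps >= 4"

ttnn_scheduler = TtPNDMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, device=device
)
ttnn_scheduler.set_timesteps(num_inference_steps)
```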

## Metrics Interpretation
`FID Score (Fréchet Inception Distance)` evaluates the quality of generated images by measuring the similarity between their feature distributions and those of real images. A lower FID score indicates better similarity between generated and real images.
For more information, refer to the [FID Score](https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html) documentation.
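
A small, self-contained torchmetrics sketch of the computation (the demo wraps this in its own `calculate_fid_score` helper; the image extras `torchmetrics[image]` are assumed to be installed, and the random tensors below stand in for real and generated image batches):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Smaller Inception feature layer keeps the example fast; the default is 2048.
fid = FrechetInceptionDistance(feature=64)

# uint8 image batches of shape (N, 3, H, W); each distribution needs at least two images.
real_images = torch.randint(0, 255, (2, 3, 256, 256), dtype=torch.uint8)
generated_images = torch.randint(0, 255, (2, 3, 256, 256), dtype=torch.uint8)

fid.update(real_images, real=True)
fid.update(generated_images, real=False)
print(f"FID: {fid.compute().item():.4f}")
```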

`CLIP Score` measures the similarity between the generated images and the input prompts. Higher CLIP scores indicate better alignment between the generated images and the provided text prompts.
For more information, refer to the [CLIP Score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html) documentation.
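
A comparable sketch for the CLIP score, using the same `openai/clip-vit-base-patch16` checkpoint the demo passes to `CLIPScore`; the images and prompts below are placeholders:

```python
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

clip = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")

# uint8 images of shape (N, 3, H, W), paired one-to-one with their text prompts.
images = torch.randint(0, 255, (2, 3, 256, 256), dtype=torch.uint8)
prompts = ["a photo of an astronaut riding a horse", "an oil painting of a lighthouse"]

score = clip(images, prompts)
print(f"CLIP score: {score.item():.4f}")
```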
@@ -11,24 +11,21 @@
from loguru import logger
from tqdm.auto import tqdm
from datasets import load_dataset
import os

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from models.utility_functions import (
skip_for_grayskull,
)
from models.utility_functions import skip_for_grayskull
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from ttnn.model_preprocessing import preprocess_model_parameters
from models.experimental.functional_stable_diffusion.sd_pndm_scheduler import TtPNDMScheduler
from models.experimental.functional_stable_diffusion.custom_preprocessing import custom_preprocessor
from models.experimental.functional_stable_diffusion.tt2.ttnn_functional_unet_2d_condition_model import (
from models.demos.wormhole.stable_diffusion.sd_pndm_scheduler import TtPNDMScheduler
from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor
from models.demos.wormhole.stable_diffusion.tt2.ttnn_functional_unet_2d_condition_model import (
UNet2DConditionModel as UNet2D,
)

@@ -65,8 +62,6 @@ def tt_guide(noise_pred, guidance_scale): # will return latents
noise_pred.shape[3] - 1,
],
)

# noise_pred_uncond, noise_pred_text = ttnn.split(noise_pred, noise_pred.shape[0] // 2, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred

@@ -108,6 +103,9 @@ def preprocess_images(image_paths):
def run_demo_inference(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size=(256, 256)):
disable_persistent_kernel_cache()

assert (
num_inference_steps >= 4
), f"PNDMScheduler only supports num_inference_steps >= 4. Found num_inference_steps={num_inference_steps}"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

@@ -247,93 +245,49 @@ def run_demo_inference_diffusiondb(
device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size=(256, 256)
):
disable_persistent_kernel_cache()
device.enable_program_cache()

assert (
num_inference_steps >= 4
), f"PNDMScheduler only supports num_inference_steps >= 4. Found num_inference_steps={num_inference_steps}"
# 0. Load a sample prompt from the dataset
dataset = load_dataset("poloclub/diffusiondb", "2m_random_1k")
data_1k = dataset["train"]

height, width = image_size

torch_device = "cpu"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(torch_device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# 4. load the K-LMS scheduler with some fitting parameters.
ttnn_scheduler = TtPNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, device=device
)

text_encoder.to(torch_device)
unet.to(torch_device)

config = unet.config
parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
input_height = 64
input_width = 64
reader_patterns_cache = {} if height == 512 and width == 512 else None
model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)

guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(174) # 10233 Seed generator to create the inital latent noise
batch_size = 1
for i in range(num_prompts):
experiment_name = f"diffusiondb_{i}__{height}x{width}"
input_prompt = [f"{data_1k['prompt'][i]}"]
logger.info(f"input_prompts: {input_prompt}")

# Initial random noise
latents = torch.randn(
(batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
generator=generator,
)
latents = latents.to(torch_device)
image = np.array(data_1k["image"][i])
ref_images = Image.fromarray(image)
ref_img_path = f"{experiment_name}_ref.png"
ref_images.save(ref_img_path)

ttnn_scheduler.set_timesteps(num_inference_steps)
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

latents = latents * ttnn_scheduler.init_noise_sigma
rand_latents = torch.tensor(latents)
rand_latents = ttnn.from_torch(rand_latents, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
ttnn_latent_model_input = ttnn.concat([rand_latents, rand_latents], dim=0)
_tlist = []
for t in ttnn_scheduler.timesteps:
_t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
_t = _t.unsqueeze(0).unsqueeze(0)
_t = _t.permute(2, 0, 1, 3) # pre-permute temb
_t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
_tlist.append(_t)
# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

time_step = ttnn_scheduler.timesteps.tolist()
# 4. load the K-LMS scheduler with some fitting parameters.
ttnn_scheduler = TtPNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, device=device
)

interactive = os.environ.get("INTERACTIVE_SD_DEMO", "0") == "1"
i = 0
while i < num_prompts:
ttnn_scheduler.set_timesteps(num_inference_steps)
if interactive:
print("Enter the input promt, or q to exit:")
input_prompt = [input()]
if input_prompt[0] == "q":
break
else:
input_prompt = [f"{data_1k['prompt'][i]}"]

image = np.array(data_1k["image"][i])
ref_images = Image.fromarray(image)
ref_img_path = f"{experiment_name}_ref.png"
ref_images.save(ref_img_path)
i = i + 1
torch_device = "cpu"
vae.to(torch_device)
text_encoder.to(torch_device)
unet.to(torch_device)

experiment_name = f"diffusiondb_{i}__{height}x{width}"
logger.info(f"input_prompts: {input_prompt}")
guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(174) # 10233 Seed generator to create the initial latent noise
batch_size = len(input_prompt)

## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model.
# Tokenizer and Text Encoder
@@ -357,10 +311,44 @@ def run_demo_inference_diffusiondb(
ttnn_text_embeddings = ttnn.from_torch(
ttnn_text_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# Initial random noise
latents = torch.randn(
(batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
generator=generator,
)
latents = latents.to(torch_device)

ttnn_scheduler.set_timesteps(num_inference_steps)

latents = latents * ttnn_scheduler.init_noise_sigma
ttnn_latents = torch.tensor(latents)
ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

config = unet.config
parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
input_height = 64
input_width = 64
reader_patterns_cache = {} if height == 512 and width == 512 else None
# ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
ttnn_latent_model_input = ttnn.concat([ttnn_latents, ttnn_latents], dim=0)
_tlist = []
for t in ttnn_scheduler.timesteps:
_t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
_t = _t.unsqueeze(0).unsqueeze(0)
_t = _t.permute(2, 0, 1, 3) # pre-permute temb
_t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
_tlist.append(_t)

time_step = ttnn_scheduler.timesteps.tolist()

model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)
iter = 0
ttnn_latents = rand_latents
# # Denoising loop
for index in tqdm(range(len(time_step))):
for index in range(len(time_step)):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
ttnn_latent_model_input = ttnn.concat([ttnn_latents, ttnn_latents], dim=0)
_t = _tlist[index]
@@ -377,12 +365,12 @@ def run_demo_inference_diffusiondb(
return_dict=True,
config=config,
)
print(f"Sample: {iter}")

# perform guidance
noise_pred = tt_guide(ttnn_output, guidance_scale)
ttnn_latents = ttnn_scheduler.step(noise_pred, t, ttnn_latents).prev_sample
if not interactive:
_save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")
_save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")

iter += 1
enable_persistent_kernel_cache()
@@ -401,15 +389,15 @@ def run_demo_inference_diffusiondb(
ttnn_output_path = f"{experiment_name}_ttnn.png"
pil_images.save(ttnn_output_path)

ref_paths = [ref_img_path, ref_img_path]
ttnn_paths = [ttnn_output_path, ttnn_output_path]

ref_images = preprocess_images(ref_paths)
ttnn_images = preprocess_images(ttnn_paths)
if not interactive:
ref_paths = [ref_img_path, ref_img_path]
ref_images = preprocess_images(ref_paths)

# Calculate FID scores
fid_score_ref_ttnn = calculate_fid_score(ref_images, ttnn_images)
logger.info(f"FID Score (Reference vs TTNN): {fid_score_ref_ttnn}")
# Calculate FID scores
fid_score_ref_ttnn = calculate_fid_score(ref_images, ttnn_images)
logger.info(f"FID Score (Reference vs TTNN): {fid_score_ref_ttnn}")

# calculate Clip score
clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
@@ -420,13 +408,14 @@ def run_demo_inference_diffusiondb(


@skip_for_grayskull()
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.parametrize(
"num_prompts",
((1),),
)
@pytest.mark.parametrize(
"num_inference_steps",
((2),),
((4),),
)
@pytest.mark.parametrize(
"image_size",
@@ -444,7 +433,7 @@ def test_demo(device, reset_seeds, input_path, num_prompts, num_inference_steps,
)
@pytest.mark.parametrize(
"num_inference_steps",
((30),),
((4),),
)
@pytest.mark.parametrize(
"image_size",
@@ -3,8 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import ttnn
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_cross_attention import cross_attention
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_feedforward import feedforward
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_cross_attention import cross_attention
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_feedforward import feedforward


def basic_transformer_block(
@@ -3,9 +3,9 @@
# SPDX-License-Identifier: Apache-2.0

import ttnn
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d


def cross_attention_down_block_2d(
@@ -5,9 +5,9 @@
import torch
import ttnn
from typing import Optional, Dict
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model


def torch_to_ttnn(input, device, layout=ttnn.TILE_LAYOUT):
@@ -5,8 +5,8 @@
import ttnn
import torch
from typing import Optional
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d


def downblock2d(
@@ -9,7 +9,7 @@
import torch.nn as nn
from tt_lib.fallback_ops import fallback_ops
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import (
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_utility_functions import (
run_ttnn_conv_with_pre_and_post_tensor_formatting,
)
import math
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import ttnn
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_geglu import geglu
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_geglu import geglu


def feedforward(config, hidden_states, parameters, device=None):
@@ -6,7 +6,7 @@

import torch
from typing import Optional, Dict
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import (
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_utility_functions import (
run_ttnn_conv_with_pre_and_post_tensor_formatting,
pre_process_input,
post_process_output,