Skip to content

Commit

Permalink
#5773: Move SD model to demo folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Sudharsan-V authored and Maganuru Jayasurya committed May 27, 2024
1 parent 9f2fd27 commit f96b82d
Show file tree
Hide file tree
Showing 62 changed files with 7,887 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
| [Mistral-7B-decode](./models/demos/wormhole/mistral7b) | 33rd | 32 | 10.9 t/s/u - 349 t/s | 13.3 t/s/u - 426 t/s | 21 t/s/u |
| [Mamba-2.8B-decode](./models/demos/mamba) | any | 32 | 9.2 t/s/u - 295 t/s | 13.1 t/s/u - 419 t/s | 22 t/s/u |
| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) | any | 8 | 270 | 340 | 400 |
| Stable Diffusion 1.4 512x512 | coming soon | 1 | | | |
| Stable Diffusion 1.4 512x512 (seconds for denoise) | | 1 | 128s | 2.5s | |

[3] - Generating the i'th token in a sequence while the kv_cache is filled with i-1 rows.

Expand Down
32 changes: 32 additions & 0 deletions models/demos/wormhole/stable_diffusion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Stable_diffusion Model

## Introduction
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.

# Details
The entry point to functional_stable_diffusion model is UNet2DConditionModel in `models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_unet_2d_condition_model.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `CompVis/stable-diffusion-v1-4` version from huggingface as our reference.

# Inputs
Inputs by default are provided from `input_data.json`. If you wish to change the inputs, provide a different path to test_demo.We do not recommend modifying `input_data.json` file.

## How to Run

To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables.
For more information, refer [installation and build guide](https://github.com/tenstorrent/tt-metal/blob/main/INSTALLING.md).

Use `pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` to run the demo.

If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="<address_to_your_json_file.json>" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo`

Our second demo is designed to run poloclub/diffusiondb dataset, run this with `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb`.

If you wish to run for `num_prompts` samples and `num_inference_steps` denoising steps, use `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[<num_prompts>-<num_inference_steps>]`

Note: ttnn stable diffusion utilizes `PNDMScheduler` and requires `num_inference_steps to be greater than or equal to 4`. [Reference](https://arxiv.org/pdf/2202.09778)

# Metrics Interpretation
`FID Score (Fréchet Inception Distance)` evaluates the quality of generated images by measuring the similarity between their feature distributions and those of real images. A lower FID score indicates better similarity between generated and real images.
For more information, refer [FID Score](https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html).

`CLIP Score` measures the similarity between the generated images and the input prompts. Higher CLIP scores indicate better alignment between the generated images and the provided text prompts.
For more information, refer [CLIP Score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html).
41 changes: 41 additions & 0 deletions models/demos/wormhole/stable_diffusion/custom_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn
import ttnn


def preprocess_groupnorm_parameter(parameter, *, dtype):
parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
return parameter


def preprocess_conv_parameter(parameter, *, dtype):
parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
return parameter


def custom_preprocessor(model, name):
parameters = {}
if isinstance(model, nn.GroupNorm):
parameters["weight"] = preprocess_groupnorm_parameter(model.weight, dtype=ttnn.bfloat16)
parameters["bias"] = preprocess_groupnorm_parameter(model.bias, dtype=ttnn.bfloat16)

if isinstance(model, nn.Conv2d):
weight = torch.permute(model.weight, (2, 3, 0, 1))
parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16)
parameters["bias"] = preprocess_conv_parameter(model.bias, dtype=ttnn.bfloat16)

if isinstance(model, (nn.Linear, nn.LayerNorm)):
weight = model.weight.T.contiguous()
while len(weight.shape) < 4:
weight = weight.unsqueeze(0)
parameters["weight"] = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
if model.bias is not None:
bias = model.bias
while len(bias.shape) < 4:
bias = bias.unsqueeze(0)
parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
return parameters
Loading

0 comments on commit f96b82d

Please sign in to comment.