-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9f2fd27
commit f96b82d
Showing
62 changed files
with
7,887 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Stable_diffusion Model | ||
|
||
## Introduction | ||
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. | ||
|
||
# Details | ||
The entry point to functional_stable_diffusion model is UNet2DConditionModel in `models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_unet_2d_condition_model.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `CompVis/stable-diffusion-v1-4` version from huggingface as our reference. | ||
|
||
# Inputs | ||
Inputs by default are provided from `input_data.json`. If you wish to change the inputs, provide a different path to test_demo.We do not recommend modifying `input_data.json` file. | ||
|
||
## How to Run | ||
|
||
To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables. | ||
For more information, refer [installation and build guide](https://github.com/tenstorrent/tt-metal/blob/main/INSTALLING.md). | ||
|
||
Use `pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` to run the demo. | ||
|
||
If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="<address_to_your_json_file.json>" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` | ||
|
||
Our second demo is designed to run poloclub/diffusiondb dataset, run this with `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb`. | ||
|
||
If you wish to run for `num_prompts` samples and `num_inference_steps` denoising steps, use `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[<num_prompts>-<num_inference_steps>]` | ||
|
||
Note: ttnn stable diffusion utilizes `PNDMScheduler` and requires `num_inference_steps to be greater than or equal to 4`. [Reference](https://arxiv.org/pdf/2202.09778) | ||
|
||
# Metrics Interpretation | ||
`FID Score (Fréchet Inception Distance)` evaluates the quality of generated images by measuring the similarity between their feature distributions and those of real images. A lower FID score indicates better similarity between generated and real images. | ||
For more information, refer [FID Score](https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html). | ||
|
||
`CLIP Score` measures the similarity between the generated images and the input prompts. Higher CLIP scores indicate better alignment between the generated images and the provided text prompts. | ||
For more information, refer [CLIP Score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html). |
41 changes: 41 additions & 0 deletions
41
models/demos/wormhole/stable_diffusion/custom_preprocessing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
from torch import nn | ||
import ttnn | ||
|
||
|
||
def preprocess_groupnorm_parameter(parameter, *, dtype): | ||
parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT) | ||
return parameter | ||
|
||
|
||
def preprocess_conv_parameter(parameter, *, dtype): | ||
parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT) | ||
return parameter | ||
|
||
|
||
def custom_preprocessor(model, name): | ||
parameters = {} | ||
if isinstance(model, nn.GroupNorm): | ||
parameters["weight"] = preprocess_groupnorm_parameter(model.weight, dtype=ttnn.bfloat16) | ||
parameters["bias"] = preprocess_groupnorm_parameter(model.bias, dtype=ttnn.bfloat16) | ||
|
||
if isinstance(model, nn.Conv2d): | ||
weight = torch.permute(model.weight, (2, 3, 0, 1)) | ||
parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16) | ||
parameters["bias"] = preprocess_conv_parameter(model.bias, dtype=ttnn.bfloat16) | ||
|
||
if isinstance(model, (nn.Linear, nn.LayerNorm)): | ||
weight = model.weight.T.contiguous() | ||
while len(weight.shape) < 4: | ||
weight = weight.unsqueeze(0) | ||
parameters["weight"] = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) | ||
if model.bias is not None: | ||
bias = model.bias | ||
while len(bias.shape) < 4: | ||
bias = bias.unsqueeze(0) | ||
parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) | ||
return parameters |
Oops, something went wrong.