Skip to content

Commit

Permalink
#4791: Implement Feedforward sub-module using ttnn for stable_diffusi…
Browse files Browse the repository at this point in the history
…on model
  • Loading branch information
Sudharsan-V committed Jan 31, 2024
1 parent 8d1e89a commit 7430214
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_geglu import geglu


def feedforward(config, hidden_states, parameters):
act = geglu(config, hidden_states, parameters.net[0])
output = act @ parameters.net[2].weight
output = ttnn.add(output, parameters.net[2].bias, memory_config=ttnn.L1_MEMORY_CONFIG)
return output
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@


def geglu(config, hidden_states, parameters):
# output = ttnn.linear(hidden_states, parameters.proj.weight, bias = parameters.proj.bias)
output = ttnn.matmul(hidden_states, parameters.proj.weight)
output = ttnn.add(output, parameters.proj.bias, memory_config=ttnn.L1_MEMORY_CONFIG)

hidden_states, gate = ttnn.split(output, split_size=output.shape[-1] // 2, dim=-1)
del output
act = ttnn.gelu(gate, memory_config=ttnn.L1_MEMORY_CONFIG)
del gate
return ttnn.mul(hidden_states, act, memory_config=ttnn.L1_MEMORY_CONFIG)
return ttnn.mul(hidden_states, act)
143 changes: 143 additions & 0 deletions tests/ttnn/integration_tests/stable_diffusion/test_feedforward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch
from diffusers import UNet2DConditionModel
import ttnn
from ttnn.model_preprocessing import preprocess_model_parameters

from models.experimental.functional_stable_diffusion.tt.ttnn_functional_feedforward import feedforward
from models.utility_functions import torch_random, skip_for_wormhole_b0

from tests.ttnn.utils_for_testing import assert_with_pcc


@skip_for_wormhole_b0()
@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
@pytest.mark.parametrize(
"N, C, H, W, index",
[
(
1,
2,
1024,
320,
0,
),
(
1,
2,
256,
640,
1,
),
(
1,
2,
64,
1280,
2,
),
(
1,
2,
16,
1280,
2,
),
],
)
def test_feedforward_256x256(device, model_name, N, C, H, W, index, reset_seeds):
input_shapes = (N, C, H, W)
model = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet").eval()
ref_model = model.down_blocks[index].attentions[0].transformer_blocks[0].ff
config = model.config
torch_hidden_states = torch_random(input_shapes, -0.1, 0.1, dtype=torch.float32)
torch_output = ref_model(torch_hidden_states)

parameters = preprocess_model_parameters(
initialize_model=lambda: ref_model,
device=device,
)

ttnn_hidden_state = ttnn.to_layout(
ttnn.to_device(ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16), device), layout=ttnn.TILE_LAYOUT
)

output = feedforward(
config,
ttnn_hidden_state,
parameters=parameters,
)
output = ttnn.from_device(output)
output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
output = ttnn.to_torch(output)

assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99)


@skip_for_wormhole_b0()
@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
@pytest.mark.parametrize(
"N, C, H, W, index",
[
(
1,
2,
4096,
320,
3,
),
(
1,
2,
1024,
640,
2,
),
(
1,
2,
256,
1280,
1,
),
(
1,
2,
64,
1280,
1,
),
],
)
def test_feedforward_512x512(device, model_name, N, C, H, W, index, reset_seeds):
input_shapes = (N, C, H, W)
model = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet").eval()
ref_model = model.up_blocks[index].attentions[0].transformer_blocks[0].ff
config = model.config
torch_hidden_states = torch_random(input_shapes, -0.1, 0.1, dtype=torch.float32)
torch_output = ref_model(torch_hidden_states)

parameters = preprocess_model_parameters(
initialize_model=lambda: ref_model,
device=device,
)

ttnn_hidden_state = ttnn.to_layout(
ttnn.to_device(ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16), device), layout=ttnn.TILE_LAYOUT
)

output = feedforward(
config,
ttnn_hidden_state,
parameters=parameters,
)
output = ttnn.from_device(output)
output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
output = ttnn.to_torch(output)

assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99)

0 comments on commit 7430214

Please sign in to comment.