-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#4789: Add upsample2d to functional_stable_diffusion model
- Loading branch information
1 parent
e383db3
commit 5571e5a
Showing
2 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
54 changes: 54 additions & 0 deletions
54
models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
import ttnn | ||
|
||
from models.utility_functions import ( | ||
torch_to_tt_tensor_rm, | ||
tt_to_torch_tensor, | ||
) | ||
|
||
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_nearest_2d import upsample_nearest2d | ||
from tt_lib.fallback_ops import fallback_ops | ||
|
||
|
||
def upsample2d( | ||
device, | ||
input, | ||
parameters, | ||
in_channels, | ||
out_channels, | ||
scale_factor=2, | ||
): | ||
tt_out = upsample_nearest2d(input, scale_factor) | ||
|
||
tt_out = ttnn.from_device(tt_out) | ||
tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) | ||
tt_out = ttnn.to_torch(tt_out) | ||
tt_out = torch_to_tt_tensor_rm(tt_out, device) | ||
|
||
weight = ttnn.to_layout(parameters.conv.weight, layout=ttnn.ROW_MAJOR_LAYOUT) | ||
weight = ttnn.to_torch(weight) | ||
weight = torch.permute(weight, (2, 3, 0, 1)) | ||
bias = ttnn.to_layout(parameters.conv.bias, layout=ttnn.ROW_MAJOR_LAYOUT) | ||
bias = ttnn.to_torch(bias) | ||
|
||
weight = torch_to_tt_tensor_rm(weight, device, put_on_device=False) | ||
bias = torch_to_tt_tensor_rm(bias, device, put_on_device=False) | ||
|
||
conv = fallback_ops.Conv2d( | ||
weight, | ||
bias, | ||
in_channels, | ||
out_channels, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1, | ||
) | ||
|
||
tt_out = conv(tt_out) | ||
torch_out = tt_to_torch_tensor(tt_out) | ||
ttnn_out = ttnn.from_torch(torch_out, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) | ||
return ttnn_out |
112 changes: 112 additions & 0 deletions
112
tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
from diffusers import StableDiffusionPipeline | ||
import pytest | ||
import ttnn | ||
|
||
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d | ||
from models.experimental.functional_stable_diffusion.custom_preprocessing import custom_preprocessor | ||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
|
||
from models.utility_functions import skip_for_wormhole_b0, tt_to_torch_tensor, torch_random | ||
|
||
|
||
def torch_to_ttnn(input, device, layout=ttnn.TILE_LAYOUT): | ||
input = ttnn.from_torch(input, ttnn.bfloat16) | ||
input = ttnn.to_layout(input, layout) | ||
input = ttnn.to_device(input, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) | ||
return input | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size, in_channels, input_height, input_width, index", | ||
[ | ||
(2, 1280, 4, 4, 0), | ||
(2, 1280, 8, 8, 1), | ||
(2, 640, 16, 16, 2), | ||
], | ||
) | ||
@pytest.mark.parametrize("scale_factor", [2]) | ||
@skip_for_wormhole_b0() | ||
def test_upsample2d_256x256(device, scale_factor, batch_size, in_channels, input_height, input_width, index): | ||
# setup pytorch model | ||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32) | ||
|
||
unet = pipe.unet | ||
unet.eval() | ||
state_dict = unet.state_dict() | ||
unet_upblock = pipe.unet.up_blocks[index] | ||
resnet_upsampler = unet_upblock.upsamplers[0] | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device | ||
) | ||
parameters = parameters.up_blocks[index].upsamplers[0] | ||
|
||
input_shape = batch_size, in_channels, input_height, input_width | ||
out_channels = in_channels | ||
|
||
input = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) | ||
torch_output = resnet_upsampler(input) | ||
|
||
tt_input_tensor = ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) | ||
tt_up = upsample2d( | ||
device, | ||
tt_input_tensor, | ||
parameters, | ||
in_channels, | ||
out_channels, | ||
scale_factor, | ||
) | ||
torch_up = ttnn.to_torch(tt_up) | ||
|
||
assert_with_pcc(torch_output, torch_up, 0.99) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size, in_channels, input_height, input_width, index", | ||
[ | ||
(2, 1280, 8, 8, 0), | ||
(2, 1280, 16, 16, 1), | ||
(2, 640, 32, 32, 2), | ||
], | ||
) | ||
@pytest.mark.parametrize("scale_factor", [2]) | ||
@skip_for_wormhole_b0() | ||
def test_upsample2d_512x512(device, scale_factor, batch_size, in_channels, input_height, input_width, index): | ||
# setup pytorch model | ||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32) | ||
|
||
unet = pipe.unet | ||
unet.eval() | ||
state_dict = unet.state_dict() | ||
unet_upblock = pipe.unet.up_blocks[index] | ||
resnet_upsampler = unet_upblock.upsamplers[0] | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device | ||
) | ||
parameters = parameters.up_blocks[index].upsamplers[0] | ||
|
||
input_shape = batch_size, in_channels, input_height, input_width | ||
out_channels = in_channels | ||
|
||
input = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) | ||
torch_output = resnet_upsampler(input) | ||
|
||
tt_input_tensor = ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) | ||
tt_up = upsample2d( | ||
device, | ||
tt_input_tensor, | ||
parameters, | ||
in_channels, | ||
out_channels, | ||
scale_factor, | ||
) | ||
torch_up = ttnn.to_torch(tt_up) | ||
|
||
assert_with_pcc(torch_output, torch_up, 0.99) |