Skip to content

Commit

Permalink
#4789: Add upsample2d to functional_stable_diffusion model
Browse files Browse the repository at this point in the history
  • Loading branch information
jayasuryamaganuru authored and saichandax committed Feb 6, 2024
1 parent e383db3 commit 5571e5a
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

from models.utility_functions import (
torch_to_tt_tensor_rm,
tt_to_torch_tensor,
)

from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_nearest_2d import upsample_nearest2d
from tt_lib.fallback_ops import fallback_ops


def upsample2d(
device,
input,
parameters,
in_channels,
out_channels,
scale_factor=2,
):
tt_out = upsample_nearest2d(input, scale_factor)

tt_out = ttnn.from_device(tt_out)
tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT)
tt_out = ttnn.to_torch(tt_out)
tt_out = torch_to_tt_tensor_rm(tt_out, device)

weight = ttnn.to_layout(parameters.conv.weight, layout=ttnn.ROW_MAJOR_LAYOUT)
weight = ttnn.to_torch(weight)
weight = torch.permute(weight, (2, 3, 0, 1))
bias = ttnn.to_layout(parameters.conv.bias, layout=ttnn.ROW_MAJOR_LAYOUT)
bias = ttnn.to_torch(bias)

weight = torch_to_tt_tensor_rm(weight, device, put_on_device=False)
bias = torch_to_tt_tensor_rm(bias, device, put_on_device=False)

conv = fallback_ops.Conv2d(
weight,
bias,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
)

tt_out = conv(tt_out)
torch_out = tt_to_torch_tensor(tt_out)
ttnn_out = ttnn.from_torch(torch_out, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
return ttnn_out
112 changes: 112 additions & 0 deletions tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
from diffusers import StableDiffusionPipeline
import pytest
import ttnn

from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d
from models.experimental.functional_stable_diffusion.custom_preprocessing import custom_preprocessor
from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn.model_preprocessing import preprocess_model_parameters

from models.utility_functions import skip_for_wormhole_b0, tt_to_torch_tensor, torch_random


def torch_to_ttnn(input, device, layout=ttnn.TILE_LAYOUT):
input = ttnn.from_torch(input, ttnn.bfloat16)
input = ttnn.to_layout(input, layout)
input = ttnn.to_device(input, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
return input


@pytest.mark.parametrize(
"batch_size, in_channels, input_height, input_width, index",
[
(2, 1280, 4, 4, 0),
(2, 1280, 8, 8, 1),
(2, 640, 16, 16, 2),
],
)
@pytest.mark.parametrize("scale_factor", [2])
@skip_for_wormhole_b0()
def test_upsample2d_256x256(device, scale_factor, batch_size, in_channels, input_height, input_width, index):
# setup pytorch model
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)

unet = pipe.unet
unet.eval()
state_dict = unet.state_dict()
unet_upblock = pipe.unet.up_blocks[index]
resnet_upsampler = unet_upblock.upsamplers[0]

parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
parameters = parameters.up_blocks[index].upsamplers[0]

input_shape = batch_size, in_channels, input_height, input_width
out_channels = in_channels

input = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32)
torch_output = resnet_upsampler(input)

tt_input_tensor = ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
tt_up = upsample2d(
device,
tt_input_tensor,
parameters,
in_channels,
out_channels,
scale_factor,
)
torch_up = ttnn.to_torch(tt_up)

assert_with_pcc(torch_output, torch_up, 0.99)


@pytest.mark.parametrize(
"batch_size, in_channels, input_height, input_width, index",
[
(2, 1280, 8, 8, 0),
(2, 1280, 16, 16, 1),
(2, 640, 32, 32, 2),
],
)
@pytest.mark.parametrize("scale_factor", [2])
@skip_for_wormhole_b0()
def test_upsample2d_512x512(device, scale_factor, batch_size, in_channels, input_height, input_width, index):
# setup pytorch model
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)

unet = pipe.unet
unet.eval()
state_dict = unet.state_dict()
unet_upblock = pipe.unet.up_blocks[index]
resnet_upsampler = unet_upblock.upsamplers[0]

parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
parameters = parameters.up_blocks[index].upsamplers[0]

input_shape = batch_size, in_channels, input_height, input_width
out_channels = in_channels

input = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32)
torch_output = resnet_upsampler(input)

tt_input_tensor = ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
tt_up = upsample2d(
device,
tt_input_tensor,
parameters,
in_channels,
out_channels,
scale_factor,
)
torch_up = ttnn.to_torch(tt_up)

assert_with_pcc(torch_output, torch_up, 0.99)

0 comments on commit 5571e5a

Please sign in to comment.