#4557: Uplift swin model to resolve errors in tests & Add test_accuracy script
jayasuryamaganuru committed Jan 19, 2024
1 parent a536801 commit 17b7b74
Showing 4 changed files with 196 additions and 64 deletions.
59 changes: 45 additions & 14 deletions models/experimental/swin/swin_utils.py
@@ -7,6 +7,11 @@
 from typing import List, Tuple, Union, Optional
 from packaging import version
+from collections import OrderedDict
+from PIL import Image
+import os
+import glob
+from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES

 from models.utility_functions import (
     tt_to_torch_tensor,
@@ -30,9 +35,7 @@ def meshgrid(
         return torch.meshgrid(*tensors, indexing=indexing)
     else:
         if indexing != "ij":
-            raise ValueError(
-                'torch.meshgrid only supports `indexing="ij"` for torch<1.10.'
-            )
+            raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
         return torch.meshgrid(*tensors)


@@ -50,11 +53,7 @@ def window_partition(input_feature, window_size, device, put_on_device=True):
         window_size,
         num_channels,
     )
-    windows = (
-        input_feature.permute(0, 1, 3, 2, 4, 5)
-        .contiguous()
-        .view(-1, window_size, window_size, num_channels)
-    )
+    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)

     windows = torch_to_tt_tensor_rm(windows, device, put_on_device=put_on_device)
     return windows
@@ -65,7 +64,7 @@ def window_reverse(windows, window_size, height, width, device, put_on_device=True):
     Merges windows to produce higher resolution features.
     """
     num_channels = windows.shape()[-1]
-    windows = tt_to_torch_tensor(windows, device)
+    windows = tt_to_torch_tensor(windows)
     windows = windows.view(
         -1,
         height // window_size,
@@ -74,11 +73,7 @@ def window_reverse(windows, window_size, height, width, device, put_on_device=True):
         window_size,
         num_channels,
     )
-    windows = (
-        windows.permute(0, 1, 3, 2, 4, 5)
-        .contiguous()
-        .view(-1, height, width, num_channels)
-    )
+    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)

     windows = torch_to_tt_tensor_rm(windows, device, put_on_device=put_on_device)
     return windows
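Note: the window_partition/window_reverse pair above is a pure permute/view roundtrip. A minimal pure-torch sanity check with toy shapes (the tt tensor conversions are omitted):

    import torch

    x = torch.arange(2 * 4 * 4 * 1, dtype=torch.float32).reshape(2, 4, 4, 1)  # (batch, H, W, C)
    ws = 2
    # partition (batch, H, W, C) into (num_windows, ws, ws, C)
    windows = x.view(2, 4 // ws, ws, 4 // ws, ws, 1).permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, 1)
    # merge the windows back to (batch, H, W, C)
    restored = windows.view(-1, 4 // ws, 4 // ws, ws, ws, 1).permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, 4, 4, 1)
    assert torch.equal(x, restored)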
@@ -92,3 +87,39 @@ def get_shape(shape):
     else:
         new_shape = shape
     return new_shape
+
+
+class InputExample(object):
+    def __init__(self, image, label=None):
+        self.image = image
+        self.label = label
+
+
+def get_input(image_path):
+    img = Image.open(image_path)
+    return img
+
+
+def get_label(image_path):
+    _, image_name = image_path.rsplit("/", 1)
+    image_name_exact, _ = image_name.rsplit(".", 1)
+    _, label_id = image_name_exact.rsplit("_", 1)
+    label = list(IMAGENET2012_CLASSES).index(label_id)
+    return label
+
+
+def get_data(input_loc):
+    img_dir = input_loc + "/"
+    data_path = os.path.join(img_dir, "*G")
+    files = sorted(glob.glob(data_path))
+    examples = []
+    for f1 in files:
+        examples.append(
+            InputExample(
+                image=get_input(f1),
+                label=get_label(f1),
+            )
+        )
+    image_examples = examples
+
+    return image_examples
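Note: a short usage sketch for the loader added above; the directory and file name here are hypothetical. get_label expects ImageNet-style names whose last underscore-separated token is a synset id present in IMAGENET2012_CLASSES, and the "*G" glob matches files ending in G (e.g. .JPEG, .PNG):

    from models.experimental.swin.swin_utils import get_data, get_label

    # hypothetical folder containing files such as ILSVRC2012_val_00000001_n01440764.JPEG
    examples = get_data("/path/to/ImageNet_data")
    print(len(examples), examples[0].label)
    print(get_label("/path/to/ImageNet_data/ILSVRC2012_val_00000001_n01440764.JPEG"))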
134 changes: 134 additions & 0 deletions models/experimental/swin/tests/test_perf_accuracy_swin.py
@@ -0,0 +1,134 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import tt_lib
import pytest
import evaluate

import os
import random
from pathlib import Path
from loguru import logger
from transformers import AutoFeatureExtractor
from PIL import Image
import torchvision.transforms as transforms
from transformers import AutoImageProcessor

from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES as labels_dict
from models.experimental.swin.tt.swin_for_image_classification import (
TtSwinForImageClassification,
)
from transformers import SwinForImageClassification as HF_SwinForImageClassification

from models.utility_functions import (
profiler,
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
torch_to_tt_tensor_rm,
tt_to_torch_tensor,
)
from models.perf.perf_utils import prep_perf_report
from models.experimental.swin.swin_utils import get_data


def run_swin_perf(device, model_name, iterations, model_location_generator):
first_key = "first_iter"
second_key = "second_iter"
third_key = "third_iter"
cpu_key = "ref_key"

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = HF_SwinForImageClassification.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name)

folder_path = str(model_location_generator("ImageNet_data"))
image_examples = get_data(folder_path)
ground_truth = []
predicted_label = []

disable_persistent_kernel_cache()
base_address = f"swin."
with torch.no_grad():
torch_model = model

tt_model = TtSwinForImageClassification(
config=model.config,
state_dict=model.state_dict(),
base_address=base_address,
device=device,
)

transform = transforms.Compose([transforms.ToTensor()])
profiler.start(cpu_key)
torch_input = transform(image_examples[0].image)
torch_input = torch.unsqueeze(torch_input, 0)
torch_output = torch_model(torch_input)
tt_lib.device.Synchronize(device)
profiler.end(cpu_key)

profiler.start(first_key)
input_image = image_examples[0].image
input = image_processor(input_image, return_tensors="pt")
tt_pixel_values = torch_to_tt_tensor_rm(input.pixel_values, device)
tt_output = tt_model(tt_pixel_values)
profiler.end(first_key)
del tt_output

enable_persistent_kernel_cache()
profiler.start(second_key)
tt_pixel_values = torch_to_tt_tensor_rm(input.pixel_values, device)
tt_output = tt_model(tt_pixel_values)
profiler.end(second_key)
del tt_output

profiler.start(third_key)
tt_lib.device.Synchronize(device)
for i in range(iterations):
input_image = image_examples[i].image
input = image_processor(input_image, return_tensors="pt")

tt_pixel_values = input.pixel_values
tt_pixel_values = torch_to_tt_tensor_rm(tt_pixel_values, device)
tt_output = tt_model(tt_pixel_values)
tt_output_torch = tt_to_torch_tensor(tt_output.logits)
tt_prediction = torch.argmax(tt_output_torch)

ground_truth.append(image_examples[i].label)
predicted_label.append(tt_prediction.item())
del tt_output, tt_output_torch, tt_prediction
profiler.end(third_key)

accuracy_metric = evaluate.load("accuracy")
accuracy = accuracy_metric.compute(references=ground_truth, predictions=predicted_label)
logger.info(f"Accuracy: {accuracy}")

first_iter_time = profiler.get(first_key)
second_iter_time = profiler.get(second_key)
third_iter_time = profiler.get(third_key)
cpu_time = profiler.get(cpu_key)
compile_time = first_iter_time - second_iter_time

prep_perf_report("Swin", 1, first_iter_time, second_iter_time, 100, 100, "", cpu_time)
logger.info(f"Swin inference time: {second_iter_time}")
logger.info(f"Swin compile time: {compile_time}")
logger.info(f"Swin inference for {iterations} samples: {third_iter_time}")


@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"model_name,iterations",
(("microsoft/swin-tiny-patch4-window7-224", 20),),
)
def test_perf_bare_metal(use_program_cache, device, model_name, iterations, model_location_generator):
run_swin_perf(device, model_name, iterations, model_location_generator)


@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize(
"model_name,iterations",
(("microsoft/swin-tiny-patch4-window7-224", 20),),
)
def test_perf_virtual_machine(use_program_cache, device, model_name, iterations, model_location_generator):
run_swin_perf(device, model_name, iterations, model_location_generator)
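Note: the accuracy check in the script reduces to the Hugging Face evaluate API. A minimal standalone sketch with toy labels (not results from the test run):

    import evaluate

    accuracy_metric = evaluate.load("accuracy")
    # toy labels: 3 of the 4 predictions match the references
    result = accuracy_metric.compute(references=[0, 1, 2, 3], predictions=[0, 1, 2, 0])
    print(result)  # {'accuracy': 0.75}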
63 changes: 16 additions & 47 deletions models/experimental/swin/tt/swin_layer.py
@@ -42,12 +42,8 @@ def __init__(
         self.window_size = config.window_size
         self.input_resolution = input_resolution

-        gamma_before = torch_to_tt_tensor_rm(
-            state_dict[f"{base_address}.layernorm_before.weight"], self.device
-        )
-        beta_before = torch_to_tt_tensor_rm(
-            state_dict[f"{base_address}.layernorm_before.bias"], self.device
-        )
+        gamma_before = torch_to_tt_tensor_rm(state_dict[f"{base_address}.layernorm_before.weight"], self.device)
+        beta_before = torch_to_tt_tensor_rm(state_dict[f"{base_address}.layernorm_before.bias"], self.device)
         self.LayerNorm_before = fallback_ops.LayerNorm(
             gamma_before, beta_before, normalized_shape=dim, eps=config.layer_norm_eps
         )
@@ -62,12 +58,8 @@ def __init__(
             device,
         )

-        gamma_after = torch_to_tt_tensor_rm(
-            state_dict[f"{base_address}.layernorm_after.weight"], self.device
-        )
-        beta_after = torch_to_tt_tensor_rm(
-            state_dict[f"{base_address}.layernorm_after.bias"], self.device
-        )
+        gamma_after = torch_to_tt_tensor_rm(state_dict[f"{base_address}.layernorm_after.weight"], self.device)
+        beta_after = torch_to_tt_tensor_rm(state_dict[f"{base_address}.layernorm_after.bias"], self.device)

         self.LayerNorm_after = fallback_ops.LayerNorm(
             gamma_after,
@@ -121,15 +113,11 @@ def get_attn_mask(self, height, width, dtype):
             img_mask = torch_to_tt_tensor_rm(img_mask, self.device, put_on_device=False)
             mask_windows = window_partition(img_mask, self.window_size, self.device, False)

-            mask_windows = fallback_ops.reshape(
-                mask_windows, -1, self.window_size * self.window_size, 1, 1
-            )
+            mask_windows = fallback_ops.reshape(mask_windows, -1, self.window_size * self.window_size, 1, 1)

             mask_windows = tt_to_torch_tensor(mask_windows).squeeze()
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(
-                attn_mask != 0, float(-100.0)
-            ).masked_fill(attn_mask == 0, float(0.0))
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

             attn_mask = torch_to_tt_tensor_rm(attn_mask, self.device, put_on_device=False)
         else:
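Note: the masked_fill pattern above is the standard shifted-window attention mask: token pairs that fall in different windows receive a large negative bias, pairs within the same window receive zero. A toy standalone illustration with hypothetical per-position window ids:

    import torch

    mask_windows = torch.tensor([[0.0, 0.0, 1.0, 1.0]])  # toy window-id assignment for 4 positions
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    # attn_mask[0][i][j] is 0.0 when positions i and j share a window, -100.0 otherwise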
@@ -163,9 +151,7 @@ def forward(

         hidden_states = self.LayerNorm_before(hidden_states)

-        hidden_states = fallback_ops.reshape(
-            hidden_states, batch_size, height, width, channels
-        )
+        hidden_states = fallback_ops.reshape(hidden_states, batch_size, height, width, channels)

         # pad hidden_states to multiples of window size
         hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
@@ -174,28 +160,21 @@ def forward(
         hidden_states = tt_to_torch_tensor(hidden_states)
         # cyclic shift
         if self.shift_size > 0:
-            shifted_hidden_states = torch.roll(
-                hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
-            )
+            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
         else:
             shifted_hidden_states = hidden_states

         # partition windows
-        shifted_hidden_states = torch_to_tt_tensor_rm(
-            shifted_hidden_states, self.device
-        )
-        hidden_states_windows = window_partition(
-            shifted_hidden_states, self.window_size, self.device, False
-        ).to(self.device)
+        shifted_hidden_states = torch_to_tt_tensor_rm(shifted_hidden_states, self.device)
+        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size, self.device, False).to(
+            self.device
+        )

         hidden_states_windows = fallback_ops.reshape(
-            hidden_states_windows, -1, self.window_size * self.window_size, channels, 1
+            hidden_states_windows, 1, -1, self.window_size * self.window_size, channels
         )
         attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)

         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device())

         attention_outputs = self.attention(
             hidden_states_windows,
             attn_mask,
@@ -205,9 +184,7 @@ def forward(

         attention_output = attention_outputs[0]

-        attention_windows = fallback_ops.reshape(
-            attention_output, -1, self.window_size, self.window_size, channels
-        )
+        attention_windows = fallback_ops.reshape(attention_output, -1, self.window_size, self.window_size, channels)

         shifted_windows = window_reverse(
             attention_windows, self.window_size, height_pad, width_pad, self.device, False
@@ -216,28 +193,20 @@ def forward(
         shifted_windows = tt_to_torch_tensor(shifted_windows)
         # reverse cyclic shift
         if self.shift_size > 0:
-            attention_windows = torch.roll(
-                shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
-            )
+            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
         else:
             attention_windows = shifted_windows
         attention_windows = torch_to_tt_tensor_rm(attention_windows, self.device)

         was_padded = pad_values[3] > 0 or pad_values[5] > 0
         if was_padded:
             attention_windows = attention_windows[:, :height, :width, :]
-        attention_windows = fallback_ops.reshape(
-            attention_windows, 1, batch_size, height * width, channels
-        )
+        attention_windows = fallback_ops.reshape(attention_windows, 1, batch_size, height * width, channels)
         hidden_states = tt_lib.tensor.add(shortcut, attention_windows)

         layer_output = self.LayerNorm_after(hidden_states)
         layer_output = self.intermediate(layer_output)
         layer_output = tt_lib.tensor.add(hidden_states, self.output(layer_output))

-        layer_outputs = (
-            (layer_output, attention_outputs[1])
-            if output_attentions
-            else (layer_output,)
-        )
+        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
         return layer_outputs
4 changes: 1 addition & 3 deletions models/experimental/swin/tt/swin_self_attention.py
@@ -114,13 +114,11 @@ def forward(
         relative_position_bias = torch_to_tt_tensor_rm(relative_position_bias, self.device, put_on_device=False)
         relative_position_bias = fallback_ops.reshape(
             relative_position_bias,
-            -1,
             self.window_size[0] * self.window_size[1],
             self.window_size[0] * self.window_size[1],
+            -1,
             1,
         )
-        relative_position_bias = tt_lib.tensor.permute(relative_position_bias, (2, 0, 1, 3))

         attention_scores = tt_lib.tensor.permute(attention_scores, (1, 2, 3, 0))
         attention_scores = tt_lib.tensor.bcast(
             attention_scores,
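Note: the bcast call above adds a per-head relative position bias to the attention scores; in plain torch the analogous operation is a broadcast add over the window/batch dimension. A rough sketch with hypothetical shapes (window_size 7, 3 heads, 8 windows), not the tt_lib layout used in the diff:

    import torch

    num_heads, ws2 = 3, 7 * 7
    scores = torch.randn(8, num_heads, ws2, ws2)  # (num_windows, heads, tokens, tokens)
    bias = torch.randn(num_heads, ws2, ws2)       # one learned bias table per head
    scores = scores + bias.unsqueeze(0)           # broadcast the bias across all windows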
