diff --git a/models/demos/swin/README.md b/models/demos/swin/README.md
new file mode 100644
index 00000000000..86b1b5f3ee3
--- /dev/null
+++ b/models/demos/swin/README.md
@@ -0,0 +1,20 @@
+## Swin Model
+
+## Platforms:
+GS E150, WH N150, WH N300
+
+## Introduction
+The Swin Transformer is a variant of the Vision Transformer that generates hierarchical feature maps by progressively merging image patches in its deeper layers. It achieves linear computational complexity relative to the input image size by restricting self-attention calculations to local windows. This design allows it to function as a versatile backbone for tasks like image classification and dense recognition. In contrast, earlier Vision Transformers generate feature maps at a single low resolution and have quadratic computational complexity, as they compute self-attention across the entire image.
+
+## Details
+The entry point to the Swin model is `swin_for_image_classification` in `models/demos/swin/tt/ttnn_optimized_swin.py`. The model picks up certain configs and weights from the Hugging Face pretrained model. We use the `microsoft/swin-tiny-patch4-window7-224` checkpoint from Hugging Face as our reference.
+
+
+## Batch size: 8
+
+Batch size determines the number of inputs processed simultaneously during training or inference, impacting computational efficiency and memory usage. It is recommended to set `batch_size` to 8.
+
+Use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-5-device_params0]` to run the ttnn_optimized_swin demo.
+
+
+If you wish to run for a different number of iterations (`n_iterations`), use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-<n_iterations>-device_params0]`, substituting the desired iteration count for `<n_iterations>`.
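+
+## Example
+The following is a condensed sketch of how the demo in `models/demos/swin/demo/demo.py` wires the pieces together; it assumes `device` is an already-open ttnn device and `inputs` is a preprocessed `[batch_size, 3, 224, 224]` image tensor. See the demo for the full data-loading and accuracy-reporting flow.
+
+```python
+import ttnn
+from transformers import SwinForImageClassification
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.demos.swin.tt import ttnn_optimized_swin
+from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask
+
+# Load the Hugging Face reference model and convert its weights to ttnn parameters
+torch_model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval()
+parameters = preprocess_model_parameters(
+    model_name="ttnn_optimized_swin",
+    initialize_model=lambda: torch_model,
+    convert_to_ttnn=lambda *_: True,
+    custom_preprocessor=ttnn_optimized_swin.custom_preprocessor,
+    device=device,
+)
+
+# Precompute the relative position bias tables and shifted-window attention masks
+bias_table = get_relative_position(torch_model.config, parameters.swin, device)
+attention_mask_list = get_attn_mask(torch_model.config, device)
+
+# Run the ttnn model on a batch of preprocessed images and take top-1 predictions
+tt_inputs = ttnn.from_torch(inputs, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
+logits = ttnn_optimized_swin.swin_for_image_classification(
+    torch_model.config,
+    tt_inputs,
+    parameters=parameters,
+    device=device,
+    bias_table=bias_table,
+    attention_mask_list=attention_mask_list,
+)
+predictions = ttnn.to_torch(logits).argmax(dim=-1)
+```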
diff --git a/models/demos/swin/demo/demo.py b/models/demos/swin/demo/demo.py
new file mode 100644
index 00000000000..49c216e8e30
--- /dev/null
+++ b/models/demos/swin/demo/demo.py
@@ -0,0 +1,112 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+import torch
+from loguru import logger
+import pytest
+
+
+from models.utility_functions import (
+    disable_compilation_reports,
+    disable_persistent_kernel_cache,
+    enable_persistent_kernel_cache,
+    profiler,
+)
+import ttnn
+
+from models.demos.swin.demo_utils import get_data_loader, get_batch, preprocess
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.demos.swin.tt import ttnn_optimized_swin
+from transformers import SwinForImageClassification as HF_SwinForImageClassification
+from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask
+
+
+def run_swin_imagenet_inference(
+    batch_size,
+    iterations,
+    imagenet_label_dict,
+    model_location_generator,
+    device,
+):
+    disable_persistent_kernel_cache()
+    disable_compilation_reports()
+    profiler.clear()
+
+    # Setup model
+    torch_model = HF_SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+    config = torch_model.config
+    torch_model.to(torch.bfloat16)
+    torch_model.eval()
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: torch_model,
+        model_name=torch_model,
+        device=device,
+        convert_to_ttnn=lambda *_: True,
+        custom_preprocessor=ttnn_optimized_swin.custom_preprocessor,
+    )
+
+    # load inputs
+    logger.info("ImageNet-1k validation Dataset")
+    input_loc = str(model_location_generator("ImageNet_data"))
+    data_loader = get_data_loader(input_loc, batch_size, iterations)
+
+    bias_table = get_relative_position(torch_model.config, parameters.swin, device)
+    attention_mask_list = get_attn_mask(torch_model.config, device)
+
+    # load ImageNet batch by batch
+    # and run inference
+    correct = 0
+    torch_ttnn_correct = 0
+    torch_correct = 0
+    for iter in range(iterations):
+        predictions = []
+        torch_predictions = []
+        inputs, labels = get_batch(data_loader)
+        torch_outputs = torch_model(inputs)
+        tt_batched_input_tensor = ttnn.from_torch(inputs, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
+        tt_output = ttnn_optimized_swin.swin_for_image_classification(
+            torch_model.config,
+            tt_batched_input_tensor,
+            parameters=parameters,
+            device=device,
+            bias_table=bias_table,
+            attention_mask_list=attention_mask_list,
+        )
+        tt_output = ttnn.to_torch(tt_output)
+        prediction = tt_output.argmax(dim=-1)
+        torch_prediction = torch_outputs[0].argmax(dim=-1)
+        for i in range(batch_size):
+            predictions.append(imagenet_label_dict[prediction[i].item()])
+            torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()])
+            logger.info(
+                f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted Label: {torch_predictions[-1]} \tPredicted Label: {predictions[-1]}"
+            )
+            if imagenet_label_dict[labels[i]] == predictions[-1]:
+                correct += 1
+            if imagenet_label_dict[labels[i]] == torch_predictions[-1]:
+                torch_correct += 1
+            if predictions[-1] == torch_predictions[-1]:
+                torch_ttnn_correct += 1
+
+        del tt_output, tt_batched_input_tensor, inputs, labels, predictions
+    accuracy = correct / (batch_size * iterations)
+    torch_accuracy = torch_correct / (batch_size * iterations)
+    torch_ttnn_accuracy = torch_ttnn_correct / (batch_size * iterations)
+
+    logger.info(f"Model Swin for Image Classification")
+    logger.info(f"TTNN Accuracy for {batch_size}x{iterations} inputs: {accuracy}")
+    logger.info(f"Torch Accuracy for {batch_size}x{iterations} inputs: {torch_accuracy}")
+    logger.info(f"Torch vs TTNN Accuracy for {batch_size}x{iterations} inputs: 
{torch_ttnn_accuracy}") + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, iterations", + ((8, 5),), +) +def test_demo_imagenet(batch_size, iterations, imagenet_label_dict, model_location_generator, device): + run_swin_imagenet_inference(batch_size, iterations, imagenet_label_dict, model_location_generator, device) diff --git a/models/demos/swin/demo_utils.py b/models/demos/swin/demo_utils.py new file mode 100644 index 00000000000..32296462f01 --- /dev/null +++ b/models/demos/swin/demo_utils.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from PIL import Image +import torch +import os +import glob +from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES +from datasets import load_dataset +from torchvision import models +from PIL import Image +import torchvision.transforms as transforms +import torch + + +class InputExample(object): + def __init__(self, image, label=None): + self.image = image + self.label = label + + +def get_input(image_path): + img = Image.open(image_path) + return img + + +def get_label(image_path): + _, image_name = image_path.rsplit("/", 1) + image_name_exact, _ = image_name.rsplit(".", 1) + _, label_id = image_name_exact.rsplit("_", 1) + label = list(IMAGENET2012_CLASSES).index(label_id) + return label + + +preprocess = transforms.Compose( + [ + transforms.Resize(256), # Resize the shorter side to 256 pixels + transforms.CenterCrop(224), # Crop the center to 224x224 pixels + transforms.ToTensor(), # Convert the image to a tensor + transforms.Normalize( # Normalize using ImageNet's mean and std + mean=[0.485, 0.456, 0.406], # These are the mean values for each channel + std=[0.229, 0.224, 0.225], # These are the std values for each channel + ), + ] +) + + +def get_batch(data_loader): + loaded_images = next(data_loader) + images = None + labels = [] + transform = transforms.ToTensor() + resize_transform = transforms.Resize((224, 224)) + for image in loaded_images: + img = image.image + labels.append(image.label) + if img.mode == "L": + img = img.convert(mode="RGB") + + img = preprocess(img) + img = img.to(torch.bfloat16) + img = img.unsqueeze(0) + if images is None: + images = img + else: + images = torch.cat((images, img), dim=0) + return images, labels + + +def get_data_loader(input_loc, batch_size, iterations): + img_dir = input_loc + "/" + data_path = os.path.join(img_dir, "*G") + files = glob.glob(data_path) + + def loader(): + examples = [] + for f1 in files: + examples.append( + InputExample( + image=get_input(f1), + label=get_label(f1), + ) + ) + if len(examples) == batch_size: + yield examples + del examples + examples = [] + + def loader_hf(): + examples = [] + for f1 in files: + examples.append( + InputExample( + image=f1["image"], + label=f1["label"], + ) + ) + if len(examples) == batch_size: + yield examples + del examples + examples = [] + + if len(files) == 0: + files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True)) + files = [] + sample_count = batch_size * iterations + for _ in range(sample_count): + files.append(next(files_raw)) + del files_raw + return loader_hf() + + return loader() + + +def get_data(input_loc): + img_dir = input_loc + "/" + data_path = os.path.join(img_dir, "*G") + files = sorted(glob.glob(data_path)) + examples = [] + for f1 in files: + examples.append( + InputExample( + image=get_input(f1), + 
label=get_label(f1), + ) + ) + image_examples = examples + + return image_examples diff --git a/models/demos/swin/tests/test_perf_swin.py b/models/demos/swin/tests/test_perf_swin.py new file mode 100644 index 00000000000..6178903f5c6 --- /dev/null +++ b/models/demos/swin/tests/test_perf_swin.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +import time + +from models.demos.swin.tt import ttnn_optimized_swin +from transformers import SwinForImageClassification as HF_SwinForImageClassification +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + profiler, +) +from ttnn.model_preprocessing import ( + preprocess_model_parameters, +) +from models.perf.perf_utils import prep_perf_report +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.utility_functions import is_grayskull, is_wormhole_b0 +from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ([8, 4, 50.00],), +) +def test_performance_swin( + batch_size, + model_name, + expected_inference_time, + expected_compile_time, + device, +): + hugging_face_reference_model = HF_SwinForImageClassification.from_pretrained(model_name) + hugging_face_reference_model.eval() + pixel_values = torch.rand(batch_size, 3, 224, 224) + # set up tokenizer + disable_persistent_kernel_cache() + + profiler.start(f"preprocessing_parameter") + parameters = preprocess_model_parameters( + model_name="ttnn_optimized_swin", + initialize_model=lambda: hugging_face_reference_model, + convert_to_ttnn=lambda *_: True, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + profiler.end(f"preprocessing_parameter") + + cpu_key = "ref_key" + + with torch.no_grad(): + profiler.start(cpu_key) + torch_out = hugging_face_reference_model(pixel_values) + profiler.end(cpu_key) + + durations = [] + for _ in range(2): + profiler.start(f"preprocessing_input") + tt_pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(hugging_face_reference_model.config, parameters.swin, device) + attention_mask_list = get_attn_mask(hugging_face_reference_model.config, device) + + profiler.end(f"preprocessing_input") + + start = time.time() + tt_output = ttnn_optimized_swin.swin_for_image_classification( + hugging_face_reference_model.config, + pixel_values=tt_pixel_values, + parameters=parameters, + device=device, + bias_table=bias_table, + attention_mask_list=attention_mask_list, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + prep_perf_report( + model_name=f"ttnn_{model_name}_optimized", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + 
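+    # The first timed run includes one-time kernel compilation and the second uses the persistent kernel cache, so their difference above approximates compile time.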
logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit Swin perf test") + + +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [8, "microsoft/swin-tiny-patch4-window7-224"], + ], +) +def test_swin_perf_device(batch_size, test, reset_seeds): + subdir = "ttnn_swin" + margin = 0.03 + num_iterations = 1 + if is_grayskull(): + expected_perf = 26 + elif is_wormhole_b0(): + expected_perf = 39 + + command = f"pytest tests/ttnn/integration_tests/swin/test_ttnn_swin.py::test_swin_for_image_classification" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_swin{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/swin/tt/swin_utils.py b/models/demos/swin/tt/swin_utils.py new file mode 100644 index 00000000000..5e74076bf2f --- /dev/null +++ b/models/demos/swin/tt/swin_utils.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import ttnn +import torch + + +def window_partition(input_feature, window_size, device): + batch_size, height, width, num_channels = input_feature.shape + input_feature = ttnn.to_torch(input_feature) + # 6D reshape and permute not supported in ttnn + input_feature = torch.reshape( + input_feature, + ( + batch_size, + height // window_size, + window_size, + width // window_size, + window_size, + num_channels, + ), + ) + windows = torch.permute(input_feature, (0, 1, 3, 2, 4, 5)) + windows = torch.reshape(windows, (-1, window_size, window_size, num_channels)) + windows = ttnn.from_torch(windows, dtype=ttnn.bfloat16) + return windows + + +def window_reverse(windows, window_size, height, width, device): + num_channels = windows.shape[-1] + + windows = ttnn.to_torch(windows) + # 6D reshape and permute not supported in ttnn + windows = torch.reshape( + windows, + ( + -1, + height // window_size, + width // window_size, + window_size, + window_size, + num_channels, + ), + ) + + windows = torch.permute(windows, (0, 1, 3, 2, 4, 5)) + windows = torch.reshape(windows, (-1, height, width, num_channels)) + windows = ttnn.from_torch(windows, dtype=ttnn.bfloat16, device=device) + + return windows + + +def get_relative_position(config, parameters, device): + window_size = 7 + + coords_h = torch.arange(window_size) + coords_w = torch.arange(window_size) + cords = [coords_h, coords_w] + coords = torch.stack(torch.meshgrid(*cords, indexing="ij")) + + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + + relative_coords[:, :, 0] += window_size - 1 + relative_coords[:, :, 1] += window_size - 1 + relative_coords[:, :, 0] *= 2 * window_size - 1 + relative_position_index = relative_coords.sum(-1) + + relative_pos_table = {} + depth = 2 + 
for i in range(4): + if i == 2: + depth = 6 + else: + depth = 2 + j = 0 + bias_table = [] + while j < depth: + relative_position_bias_table = ttnn.to_torch( + parameters.encoder.layers[i].blocks[j].attention.relative_position_bias_table + ) + + relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + window_size * window_size, window_size * window_size, -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + relative_position_bias = ttnn.from_torch( + relative_position_bias.unsqueeze(0), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + bias_table.append(relative_position_bias) + j += 1 + relative_pos_table[i] = bias_table + return relative_pos_table + + +# def get_attn_mask() +def get_attn_mask(config, device): + height = [56, 28, 14, 7] + shift_size = [0, 3, 0, 3, 0, 3, 0, 3, 0, 3, 0, 3] + window_size = 7 + shift_idx = 0 + attention_mask_list = [] + for i in range(len(config.depths)): + mask_list = [] + for j in range(config.depths[i]): + if shift_size[shift_idx] > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height[i], height[i], 1)) + height_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size[shift_idx]), + slice(-shift_size[shift_idx], None), + ) + width_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size[shift_idx]), + slice(-shift_size[shift_idx], None), + ) + + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + img_mask = ttnn.from_torch(img_mask, dtype=ttnn.bfloat16) + mask_windows = window_partition(img_mask, window_size, device) + + mask_windows = ttnn.reshape(mask_windows, (-1, window_size * window_size, 1, 1)) + + mask_windows = ttnn.to_torch(mask_windows) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + attn_mask = ttnn.from_torch(attn_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + attn_mask = ttnn.to_device(attn_mask, device=device) + else: + attn_mask = None + + mask_list.append(attn_mask) + shift_idx += 1 + + attention_mask_list.append(mask_list) + + return attention_mask_list diff --git a/models/demos/swin/tt/ttnn_optimized_swin.py b/models/demos/swin/tt/ttnn_optimized_swin.py new file mode 100644 index 00000000000..e5784a6b5d1 --- /dev/null +++ b/models/demos/swin/tt/ttnn_optimized_swin.py @@ -0,0 +1,625 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from torch import nn +from models.demos.swin.tt.swin_utils import window_partition, window_reverse +from ttnn.model_preprocessing import ( + preprocess_linear_bias, + preprocess_linear_weight, +) + + +def patch_embeddings(config, pixel_values, parameters, device): + _, num_channels, height, width = pixel_values.shape + pixel_values = ttnn.to_torch(pixel_values).to(torch.float) + weight = ttnn.to_torch(parameters.embeddings.patch_embeddings.projection.weight).to(torch.float) + bias = ttnn.to_torch(parameters.embeddings.patch_embeddings.projection.bias).to(torch.float) + projection = nn.Conv2d( + in_channels=3, + out_channels=96, + kernel_size=4, + stride=4, + padding=0, + ) + projection.weight = nn.Parameter(weight) + projection.bias = nn.Parameter(bias.squeeze(0).squeeze(0).squeeze(0)) + embeddings = projection(pixel_values) + + embeddings = ttnn.from_torch(embeddings, dtype=ttnn.bfloat16) + + batch, channel, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = ttnn.reshape(embeddings, (1, batch, channel, height * width)) + + embeddings = ttnn.to_layout(embeddings, layout=ttnn.TILE_LAYOUT) + embeddings = ttnn.to_device(embeddings, device=device) + embeddings = ttnn.permute(embeddings, (0, 1, 3, 2)) + embeddings = ttnn.reshape(embeddings, (embeddings.shape[1], embeddings.shape[2], embeddings.shape[3])) + return embeddings, output_dimensions + + +def embeddings(config, pixel_values, position_embeddings=None, bool_masked_pos=None, parameters=None, device=None): + embeddings, output_dimensions = patch_embeddings(config, pixel_values, parameters, device) + embeddings = ttnn.layer_norm( + embeddings, + weight=parameters.embeddings.norm.weight, + bias=parameters.embeddings.norm.bias, + epsilon=1e-05, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if config.use_absolute_embeddings and position_embeddings is not None: + embeddings = embeddings + position_embeddings + + return embeddings, output_dimensions + + +def self_attention( + config, + dim, + num_heads, + window_size, + hidden_states, + attention_mask, + head_mask=None, + output_attentions=None, + parameters=None, + device=None, + relative_position_bias=None, +): + batch_size, c, num_channels = hidden_states.shape + num_attention_heads = num_heads + attention_head_size = int(dim / num_heads) + all_head_size = num_attention_heads * attention_head_size + + query_key_value_output = ttnn.linear( + hidden_states, + parameters.query_key_value.weight, + bias=parameters.query_key_value.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + ( + query_layer, + key_layer, + value_layer, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value_output, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=num_heads, + ) + attention_scores = ttnn.matmul(query_layer, key_layer) + attention_head_size = int(dim / num_heads) + + attention_scores = ttnn.mul(attention_scores, (1 / (attention_head_size ** (1 / 2)))) + + attention_scores = ttnn.add( + attention_scores, + relative_position_bias, + ) + + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + + attention_scores = ttnn.to_layout(ttnn.from_device(attention_scores), layout=ttnn.ROW_MAJOR_LAYOUT) + attention_scores = ttnn.reshape( + attention_scores, + ( + batch_size // mask_shape, + mask_shape, + num_heads, + c, + c, + ), + ) + attention_scores = 
ttnn.to_layout(attention_scores, layout=ttnn.TILE_LAYOUT) + attention_scores = ttnn.to_device(attention_scores, device=device) + + attention_mask = ttnn.from_device(attention_mask) + attention_mask = ttnn.to_layout(attention_mask, layout=ttnn.ROW_MAJOR_LAYOUT) + + attention_mask = ttnn.reshape( + attention_mask, (1, attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[2]) + ) + + attention_mask = ttnn.to_layout(attention_mask, layout=ttnn.TILE_LAYOUT) + attention_mask = ttnn.to_device(attention_mask, device=device) + + attention_scores = attention_scores + attention_mask + attention_scores = ttnn.from_device(attention_scores) + attention_scores = ttnn.to_layout(attention_scores, layout=ttnn.ROW_MAJOR_LAYOUT) + attention_scores = ttnn.reshape(attention_scores, (-1, num_heads, c, c)) + + attention_scores = ttnn.to_layout(attention_scores, layout=ttnn.TILE_LAYOUT) + attention_scores = ttnn.to_device(attention_scores, device=device) + attention_probs = ttnn.softmax(attention_scores, dim=-1) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + context_layer = ttnn.matmul(attention_probs, value_layer) + context_layer = ttnn.permute(context_layer, (0, 2, 1, 3)) + + new_context_layer_shape = tuple(context_layer.shape.with_tile_padding())[:-2] + (num_heads * attention_head_size,) + context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT) + context_layer = ttnn.reshape( + ttnn.from_device(context_layer), + ( + 1, + new_context_layer_shape[0], + new_context_layer_shape[1], + new_context_layer_shape[2], + ), + ) + context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT) + context_layer = ttnn.to_device(context_layer, device=device) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +def attention( + config, + dim, + num_heads, + window_size, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=None, + parameters=None, + device=None, + relative_position_bias=None, +): + self_output = self_attention( + config, + dim, + num_heads, + window_size, + hidden_states, + attention_mask, + head_mask, + output_attentions, + parameters, + device, + relative_position_bias=relative_position_bias, + ) + + attention_output = ttnn.linear( + self_output[0], + parameters.output.weight, + bias=parameters.output.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + outputs = (attention_output,) + self_output[1:] + return outputs + + +def maybe_pad(hidden_states, height, width, window_size): + pad_right = (window_size - width % window_size) % window_size + pad_bottom = (window_size - height % window_size) % window_size + pad_values = [(0, 0), (0, pad_right), (0, pad_bottom)] + hidden_states = ttnn.pad(hidden_states, pad_values, value=0) + return hidden_states, pad_values + + +def swin_intermediate(config, dim, hidden_states, parameter, device): + return ttnn.linear( + hidden_states, + parameter.dense.weight, + bias=parameter.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + activation="gelu", + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + +def swin_layer( + config, + dim, + input_resolution, + num_heads, + shift_size, + hidden_states, + input_dimensions, + head_mask=None, + output_attentions=None, + parameters=None, + device=None, + relative_position_bias=None, + attn_mask=None, +): + height, width = 
input_dimensions + window_size = config.window_size + if min(input_dimensions) < config.window_size: + shift_size = 0 + window_size = min(input_dimensions) + + batch_size, _, channels = hidden_states.shape + shortcut = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + epsilon=1e-05, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + hidden_states = ttnn.to_layout(hidden_states, layout=ttnn.ROW_MAJOR_LAYOUT) + hidden_states = ttnn.reshape(hidden_states, (batch_size, height, width, channels)) + hidden_states, pad_values = maybe_pad(hidden_states, height, width, window_size) + _, height_pad, width_pad, _ = hidden_states.shape + + if shift_size > 0: + shifted_hidden_states = torch.roll(ttnn.to_torch(hidden_states), shifts=(-shift_size, -shift_size), dims=(1, 2)) + shifted_hidden_states = ttnn.from_torch( + shifted_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + else: + shifted_hidden_states = hidden_states + + hidden_states_windows = window_partition(shifted_hidden_states, window_size, device) + hidden_states_windows = ttnn.reshape(hidden_states_windows, (-1, window_size * window_size, channels)) + hidden_states_windows = ttnn.to_layout(hidden_states_windows, layout=ttnn.TILE_LAYOUT) + hidden_states_windows = ttnn.to_device(hidden_states_windows, device=device) + attention_outputs = attention( + config, + dim=dim, + num_heads=num_heads, + window_size=window_size, + hidden_states=hidden_states_windows, + attention_mask=attn_mask, + head_mask=head_mask, + output_attentions=output_attentions, + parameters=parameters.attention, + device=device, + relative_position_bias=relative_position_bias, + ) + + attention_output = attention_outputs[0] + + attention_output = ttnn.to_layout(attention_output, layout=ttnn.ROW_MAJOR_LAYOUT) + attention_windows = ttnn.reshape(attention_output, (-1, window_size, window_size, channels)) + + shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad, device) + + if shift_size > 0: + attention_windows = torch.roll(ttnn.to_torch(shifted_windows), shifts=(shift_size, shift_size), dims=(1, 2)) + attention_windows = ttnn.from_torch(attention_windows, dtype=ttnn.bfloat16, device=device) + else: + attention_windows = shifted_windows + + was_padded = pad_values[1][1] > 0 or pad_values[2][1] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :] + attention_windows = ttnn.reshape(attention_windows, (batch_size, height * width, channels)) + attention_windows = ttnn.to_layout(attention_windows, layout=ttnn.TILE_LAYOUT) + hidden_states = ttnn.add(shortcut, attention_windows) + + layer_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + epsilon=1e-05, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + layer_output = ttnn.linear( + layer_output, + parameters.intermediate.dense.weight, + bias=parameters.intermediate.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + activation="gelu", + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + layer_output = ttnn.linear( + layer_output, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + layer_output = hidden_states + layer_output + layer_outputs = (layer_output, 
attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +def patch_merge_pad(input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = ttnn.pad(input_feature, pad_values) + + return input_feature + + +def patch_merging(config, input_resolution, dim, input_feature, input_dimensions, parameter, device): + height, width = input_dimensions + batch_size, dim, num_channels = input_feature.shape + input_feature = ttnn.to_layout(input_feature, layout=ttnn.ROW_MAJOR_LAYOUT) + input_feature = ttnn.reshape(input_feature, (batch_size, height, width, num_channels)) + input_feature = patch_merge_pad(input_feature, height, width) + + input_feature_0 = input_feature[:, 0::2, 0::2, :] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + + input_feature = ttnn.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + + input_feature = ttnn.to_layout(input_feature, layout=ttnn.ROW_MAJOR_LAYOUT) + + input_feature = ttnn.reshape(input_feature, (batch_size, -1, 4 * num_channels)) + input_feature = ttnn.to_layout(input_feature, layout=ttnn.TILE_LAYOUT) + + input_feature = ttnn.layer_norm( + input_feature, + weight=parameter.downsample.norm.weight, + bias=parameter.downsample.norm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + input_feature = ttnn.linear( + input_feature, + parameter.downsample.reduction.weight, + bias=None, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + return input_feature + + +def swin_stage( + config, + dim, + input_resolution, + hidden_states, + input_dimensions, + depth, + layer_head_mask=None, + output_attention=None, + num_heads=None, + downsample=None, + parameter=None, + device=None, + relative_position_bias=None, + attn_mask_list=None, +): + height, width = input_dimensions + + # for block in parameter.blocks: + for i in range(depth): + layer_head_mask = None + layer_outputs = swin_layer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if i % 2 == 0 else config.window_size // 2, + hidden_states=hidden_states, + input_dimensions=input_dimensions, + head_mask=layer_head_mask, + output_attentions=output_attention, + parameters=parameter.blocks[i], + device=device, + relative_position_bias=relative_position_bias[i], + attn_mask=attn_mask_list[i], + ) + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if downsample: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = patch_merging( + config, input_resolution, dim, hidden_states_before_downsampling, input_dimensions, parameter, device + ) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = ( + hidden_states, + hidden_states_before_downsampling, + output_dimensions, + ) + + return stage_outputs + + +def encoder( + config, + hidden_state, + input_dimension, + head_mask=None, + output_attention=None, + output_hidden_states=None, + parameters=None, + device=None, + bias_table=None, + attention_mask_list=None, +): + if output_hidden_states: + batch_size, _, hidden_size = hidden_state.shape + image_size = 
(config.image_size, config.image_size) + patch_size = (config.patch_size, config.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + grid_size = ( + image_size[0] // patch_size[0], + image_size[1] // patch_size[1], + ) + for i_layer in range(len(config.depths)): + layer_head_mask = None + layer_outputs = swin_stage( + config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=( + grid_size[0] // (2**i_layer), + grid_size[1] // (2**i_layer), + ), + hidden_states=hidden_state, + input_dimensions=input_dimension, + layer_head_mask=layer_head_mask, + output_attention=output_attention, + num_heads=config.num_heads[i_layer], + depth=config.depths[i_layer], + downsample=True if (i_layer < len(config.depths) - 1) else False, + parameter=parameters.layers[i_layer], + device=device, + relative_position_bias=bias_table[i_layer], + attn_mask_list=attention_mask_list[i_layer], + ) + hidden_state = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + input_dimension = (output_dimensions[-2], output_dimensions[-1]) + return hidden_state + + +def swin( + config, + pixel_values, + bool_masked_pos=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + parameters=None, + device=None, + bias_table=None, + attention_mask_list=None, +): + output_attentions = output_attentions if output_attentions is not None else config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else config.output_hidden_states + head_mask = [ + None, + ] * len(config.depths) + + embedding_output, input_dimensions = embeddings( + config=config, pixel_values=pixel_values, bool_masked_pos=bool_masked_pos, parameters=parameters, device=device + ) + sequence_output = encoder( + config, + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attention=output_attentions, + output_hidden_states=output_hidden_states, + parameters=parameters.encoder, + device=device, + bias_table=bias_table, + attention_mask_list=attention_mask_list, + ) + sequence_output = ttnn.to_device(sequence_output, device=device) + sequence_output = ttnn.layer_norm( + sequence_output, + weight=parameters.layernorm.weight, + bias=parameters.layernorm.bias, + epsilon=config.layer_norm_eps, + # memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + pooler = nn.AdaptiveAvgPool1d(1) + sequence_output_1 = ttnn.to_torch(sequence_output) + + pooled_output = pooler(sequence_output_1.transpose(1, 2)) + pooled_output = ttnn.from_torch(pooled_output, dtype=ttnn.bfloat16) + pooled_output = ttnn.reshape( + pooled_output, (pooled_output.shape[0], pooled_output.shape[1] * pooled_output.shape[2]) + ) + + return sequence_output, pooled_output + + +def swin_for_image_classification( + config, + pixel_values, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + bias_table=None, + attention_mask_list=None, + *, + parameters, + device, +): + outputs = swin( + config=config, + pixel_values=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + parameters=parameters.swin, + device=device, + bias_table=bias_table, + attention_mask_list=attention_mask_list, + ) + pooled_output = outputs[1] + + pooled_output = ttnn.to_layout(pooled_output, layout=ttnn.TILE_LAYOUT) + pooled_output = ttnn.to_device(pooled_output, device=device) + + logits = ttnn.linear( + pooled_output, + parameters.classifier.weight, + 
bias=parameters.classifier.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + return logits + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, nn.Conv2d): + weight = model.weight + bias = model.bias + while weight.dim() < 4: + weight = weight.unsqueeze(0) + while bias.dim() < 4: + bias = bias.unsqueeze(0) + parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16) + parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.bfloat16) + + if hasattr(model, "self"): + qkv_weight = torch.cat( + [ + model.self.query.weight, + model.self.key.weight, + model.self.value.weight, + ], + dim=0, + ) + qkv_bias = torch.cat( + [model.self.query.bias, model.self.key.bias, model.self.value.bias], + dim=0, + ) + output_weight = model.output.dense.weight + output_bias = model.output.dense.bias + parameters = {"query_key_value": {}, "relative_position_bias_table": {}, "output": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat16) + parameters["output"]["weight"] = preprocess_linear_weight(output_weight, dtype=ttnn.bfloat16) + parameters["output"]["bias"] = preprocess_linear_bias(output_bias, dtype=ttnn.bfloat16) + parameters["relative_position_bias_table"] = ttnn.from_torch( + model.self.relative_position_bias_table, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index c251fa4ccb3..702e4357747 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -25,6 +25,8 @@ run_perf_models_other() { env pytest models/demos/distilbert/tests/test_perf_distilbert.py -m $test_marker + env pytest models/demos/swin/tests/test_perf_swin.py -m $test_marker + env pytest -n auto tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker @@ -91,6 +93,8 @@ run_device_perf_models() { env pytest models/demos/distilbert/tests -m $test_marker + env pytest models/demos/swin/tests/ -m $test_marker + env pytest models/demos/vgg/tests/ -m $test_marker env pytest models/demos/convnet_mnist/tests/ -m $test_marker diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 0994d8fe24b..f2ad580f4fb 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -43,6 +43,9 @@ run_common_func_tests() { # ConvNet Mnist pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$? + # Swin + pytest --disable-warnings models/demos/swin/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/ttnn/integration_tests/swin/test_ttnn_swin.py b/tests/ttnn/integration_tests/swin/test_ttnn_swin.py new file mode 100644 index 00000000000..b53041a043b --- /dev/null +++ b/tests/ttnn/integration_tests/swin/test_ttnn_swin.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from transformers import SwinModel + +from models.demos.swin.tt import ttnn_optimized_swin +from transformers import SwinForImageClassification as HF_SwinForImageClassification +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask +from transformers import AutoFeatureExtractor +from models.utility_functions import is_grayskull, is_wormhole_b0 + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_patch_embedding(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + # Torch swinpatchembedding + torch_model = model.embeddings.patch_embeddings + pixel_values = torch.rand(batch_size, 3, 224, 224) + + torch_output = torch_model(pixel_values) + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_pixel = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + + tt_output = ttnn_optimized_swin.patch_embeddings(config, tt_pixel, parameters, device) + tt_output_tensor = ttnn.to_torch(tt_output[0]) + assert_with_pcc(torch_output[0], tt_output_tensor.squeeze(0), 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_embedding(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + pixel_values = torch.rand(batch_size, 3, 224, 224) + torch_model = model.embeddings + torch_output = torch_model(pixel_values) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_pixel = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + image_size = (config.image_size, config.image_size) + patch_size = (config.patch_size, config.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + position_ids = torch.zeros(1, num_patches + 1, config.embed_dim) + + tt_position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, layout=ttnn.TILE_LAYOUT) + tt_output = ttnn_optimized_swin.embeddings( + config, + pixel_values=tt_pixel, + position_embeddings=tt_position_ids, + parameters=parameters, + device=device, + ) + tt_output_tensor = ttnn.to_torch(tt_output[0]) + assert_with_pcc(torch_output[0], tt_output_tensor, 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_self_attention(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model.encoder.layers[0].blocks[0].attention.self + num_heads, window_size, dim = 3, 7, 96 + + hidden_states = torch.rand(64, 49, 96) + attention_mask = torch.ones(64, 49, 49) + + torch_output = torch_model(hidden_states, attention_mask) + + parameters = 
preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_hidden_states = ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + tt_attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(model.config, parameters, device) + tt_output = ttnn_optimized_swin.self_attention( + model.config, + dim=dim, + num_heads=num_heads, + window_size=window_size, + hidden_states=tt_hidden_states, + attention_mask=tt_attention_mask, + parameters=parameters.encoder.layers[0].blocks[0].attention, + device=device, + relative_position_bias=bias_table[0][0], + ) + + tt_output_tensor = ttnn.to_torch(tt_output[0]).squeeze(0) + assert_with_pcc(torch_output[0], tt_output_tensor, 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_attention(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model.encoder.layers[0].blocks[0].attention + num_heads, window_size, dim = 3, 7, 96 + + hidden_states = torch.rand(64, 49, 96) + + torch_output = torch_model(hidden_states) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_hidden_states = ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(model.config, parameters, device) + tt_output = ttnn_optimized_swin.attention( + model.config, + dim=dim, + num_heads=num_heads, + window_size=window_size, + hidden_states=tt_hidden_states, + parameters=parameters.encoder.layers[0].blocks[0].attention, + device=device, + relative_position_bias=bias_table[0][0], + ) + tt_output_tensor = ttnn.to_torch(tt_output[0]).squeeze(0) + assert_with_pcc(torch_output[0], tt_output_tensor, 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_layer(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model.encoder.layers[0].blocks[0] + num_heads, window_size, dim = 3, 7, 96 + + hidden_states = torch.rand(1, 3136, 96) + input_resolution = (56, 56) + shift_size = 0 + + torch_output = torch_model(hidden_states, input_resolution) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_hidden_states = ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(model.config, parameters, device) + attention_mask_list = get_attn_mask(model.config, device) + + tt_output = ttnn_optimized_swin.swin_layer( + model.config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=shift_size, + hidden_states=tt_hidden_states, + input_dimensions=(56, 56), + parameters=parameters.encoder.layers[0].blocks[0], + device=device, + relative_position_bias=bias_table[0][0], + attn_mask=attention_mask_list[0][0], + ) + 
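+    # swin_layer returns a tuple whose first element is the block output of shape (1, 3136, 96); it is compared against the torch block output below.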
tt_output_tensor = ttnn.to_torch(tt_output[0]) + + assert_with_pcc(torch_output[0], tt_output_tensor, 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_stage(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model.encoder.layers[0] + dim = 96 + input_resolution = (56, 56) + depth = 2 + num_heads = 3 + + hidden_states = torch.rand(1, 3136, 96) + input_resolution = (56, 56) + shift_size = 0 + input_dimensions = (56, 56) + + torch_output = torch_model(hidden_states, input_resolution) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_hidden_states = ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(model.config, parameters, device) + attention_mask_list = get_attn_mask(model.config, device) + tt_output = ttnn_optimized_swin.swin_stage( + model.config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + hidden_states=tt_hidden_states, + input_dimensions=(56, 56), + depth=config.depths[0], + downsample=True, + parameter=parameters.encoder.layers[0], + device=device, + relative_position_bias=bias_table[0], + attn_mask_list=attention_mask_list[0], + ) + tt_output_tensor = ttnn.to_torch(tt_output[0]) + assert_with_pcc(torch_output[0], tt_output_tensor, 0.99) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_encoder(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model.encoder + dim = 96 + input_resolution = (56, 56) + num_heads = 3 + + hidden_states = torch.rand(1, 3136, 96) + + torch_output = torch_model(hidden_states, input_resolution) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_hidden_states = ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + bias_table = get_relative_position(model.config, parameters, device) + attention_mask_list = get_attn_mask(model.config, device) + + tt_output = ttnn_optimized_swin.encoder( + model.config, + hidden_state=tt_hidden_states, + input_dimension=(56, 56), + parameters=parameters.encoder, + device=device, + bias_table=bias_table, + attention_mask_list=attention_mask_list, + ) + + tt_output_tensor = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output[0], tt_output_tensor, 0.98) + + +@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"]) +@pytest.mark.parametrize("batch_size", [8]) +def test_swin_model(device, model_name, batch_size, reset_seeds): + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval() + config = model.config + + torch_model = model + + pixel_values = torch.rand(8, 3, 224, 224) + + torch_output = torch_model(pixel_values) + + parameters = preprocess_model_parameters( + model_name=model, + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, + device=device, + ) + tt_pixel_values = 
ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
+    bias_table = get_relative_position(model.config, parameters, device)
+    attention_mask_list = get_attn_mask(model.config, device)
+    tt_output = ttnn_optimized_swin.swin(
+        model.config,
+        pixel_values=tt_pixel_values,
+        parameters=parameters,
+        device=device,
+        bias_table=bias_table,
+        attention_mask_list=attention_mask_list,
+    )
+
+    tt_sequence_output = ttnn.to_torch(tt_output[0])
+    tt_pooled_output = ttnn.to_torch(tt_output[1])
+
+    assert_with_pcc(torch_output[0], tt_sequence_output, 0.72)
+    assert_with_pcc(torch_output[1], tt_pooled_output, 0.90)
+
+
+@pytest.mark.parametrize("model_name", ["microsoft/swin-tiny-patch4-window7-224"])
+@pytest.mark.parametrize("batch_size", [8])
+def test_swin_for_image_classification(device, model_name, batch_size, reset_seeds):
+    model = HF_SwinForImageClassification.from_pretrained(model_name)
+
+    config = model.config
+    torch_model = model
+
+    pixel_values = torch.rand(batch_size, 3, 224, 224)
+
+    torch_output = torch_model(pixel_values)
+
+    parameters = preprocess_model_parameters(
+        model_name="ttnn_optimized_swin",
+        initialize_model=lambda: model,
+        convert_to_ttnn=lambda *_: True,
+        custom_preprocessor=ttnn_optimized_swin.custom_preprocessor,
+        device=device,
+    )
+
+    tt_pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
+    bias_table = get_relative_position(model.config, parameters.swin, device)
+    attention_mask_list = get_attn_mask(model.config, device)
+
+    tt_output = ttnn_optimized_swin.swin_for_image_classification(
+        model.config,
+        pixel_values=tt_pixel_values,
+        parameters=parameters,
+        device=device,
+        bias_table=bias_table,
+        attention_mask_list=attention_mask_list,
+    )
+
+    tt_output = ttnn.to_torch(tt_output)
+    pcc = 0.91
+    # PCC threshold is lower on Grayskull
+    if is_grayskull():
+        pcc = 0.83
+    assert_with_pcc(torch_output.logits, tt_output, pcc)