#13403: Add ttnn support for swin model
Sudharsan-V committed Nov 4, 2024
1 parent d46ba83 commit 1536046
Showing 9 changed files with 1,536 additions and 0 deletions.
20 changes: 20 additions & 0 deletions models/demos/swin/README.md
@@ -0,0 +1,20 @@
## Swin Model

## Platforms
GS E150, WH N150, WH N300

## Introduction
The Swin Transformer is a variant of the Vision Transformer that generates hierarchical feature maps by progressively merging image patches in its deeper layers. It achieves linear computational complexity relative to the input image size by restricting self-attention calculations to local windows. This design allows it to function as a versatile backbone for tasks like image classification and dense recognition. In contrast, earlier Vision Transformers generate feature maps at a single low resolution and have quadratic computational complexity, as they compute self-attention across the entire image.
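
As a rough sketch of the windowing idea (illustrative PyTorch; `window_partition` here is a hypothetical helper, not part of this commit's ttnn implementation):

```python
import torch

def window_partition(x, window_size=7):
    # x: (B, H, W, C) patch features; H and W are assumed divisible by window_size
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Result: (num_windows * B, window_size * window_size, C). Self-attention is
    # computed per 7x7 window, so cost grows linearly with H * W rather than
    # quadratically as in global attention.
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)
```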

## Details
The entry point to the Swin model is `swin_for_image_classification` in `models/demos/swin/tt/ttnn_optimized_swin.py`. The model loads its configuration and weights from the Hugging Face pretrained model; we use the `microsoft/swin-tiny-patch4-window7-224` checkpoint as our reference.
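
For example, the reference model and the config consumed by the ttnn model are loaded as in the demo:

```python
from transformers import SwinForImageClassification

torch_model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
config = torch_model.config  # consumed by swin_for_image_classification
```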


## Batch size: 8

Batch size determines the number of input samples processed simultaneously during inference, impacting computational efficiency and memory usage. It is recommended to set `batch_size` to 8.
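
For reference, a preprocessed batch of 8 ImageNet images has shape `[8, 3, 224, 224]`. A minimal sketch of the host-to-device transfer used by the demo, assuming an already-open ttnn `device`:

```python
import torch
import ttnn

inputs = torch.randn(8, 3, 224, 224, dtype=torch.bfloat16)  # stand-in for a real preprocessed batch
tt_inputs = ttnn.from_torch(inputs, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
```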

Use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-5-device_params0]` to run the ttnn_optimized_swin demo.


If you wish to run for a different number of iterations (each iteration processes one batch of `batch_size` samples), use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-<n_iterations>-device_params0]`.
112 changes: 112 additions & 0 deletions models/demos/swin/demo/demo.py
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
from loguru import logger
import pytest


from models.utility_functions import (
disable_compilation_reports,
disable_persistent_kernel_cache,
profiler,
)
import ttnn

from models.demos.swin.demo_utils import get_data_loader, get_batch, preprocess
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.swin.tt import ttnn_optimized_swin
from transformers import SwinForImageClassification as HF_SwinForImageClassification
from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask


def run_swin_imagenet_inference(
batch_size,
iterations,
imagenet_label_dict,
model_location_generator,
device,
):
disable_persistent_kernel_cache()
disable_compilation_reports()
profiler.clear()

# Setup model
torch_model = HF_SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
config = torch_model.config
torch_model.to(torch.bfloat16)
torch_model.eval()

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model,
model_name="ttnn_optimized_swin",  # cache name; a string, not the model object
device=device,
convert_to_ttnn=lambda *_: True,
custom_preprocessor=ttnn_optimized_swin.custom_preprocessor,
)

# load inputs
logger.info("ImageNet-1k validation Dataset")
input_loc = str(model_location_generator("ImageNet_data"))
data_loader = get_data_loader(input_loc, batch_size, iterations)

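# Precompute the relative position bias table and the attention masks once;
# both are reused for every batch in the inference loop below.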
bias_table = get_relative_position(torch_model.config, parameters.swin, device)
attention_mask_list = get_attn_mask(torch_model.config, device)

# load ImageNet batch by batch
# and run inference
correct = 0
torch_ttnn_correct = 0
torch_correct = 0
for iteration in range(iterations):
predictions = []
torch_predictions = []
inputs, labels = get_batch(data_loader)
torch_outputs = torch_model(inputs)
tt_batched_input_tensor = ttnn.from_torch(inputs, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
tt_output = ttnn_optimized_swin.swin_for_image_classification(
torch_model.config,
tt_batched_input_tensor,
parameters=parameters,
device=device,
bias_table=bias_table,
attention_mask_list=attention_mask_list,
)
tt_output = ttnn.to_torch(tt_output)
prediction = tt_output.argmax(dim=-1)
torch_prediction = torch_outputs[0].argmax(dim=-1)
for i in range(batch_size):
predictions.append(imagenet_label_dict[prediction[i].item()])
torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()])
logger.info(
f"Iter: {iteration} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted Label: {torch_predictions[-1]} \tTTNN Predicted Label: {predictions[-1]}"
)
if imagenet_label_dict[labels[i]] == predictions[-1]:
correct += 1
if imagenet_label_dict[labels[i]] == torch_predictions[-1]:
torch_correct += 1
if predictions[-1] == torch_predictions[-1]:
torch_ttnn_correct += 1

del tt_output, tt_batched_input_tensor, inputs, labels, predictions
accuracy = correct / (batch_size * iterations)
torch_accuracy = torch_correct / (batch_size * iterations)
torch_ttnn_accuracy = torch_ttnn_correct / (batch_size * iterations)

logger.info(f"Model Swin for Image Classification")
logger.info(f"TTNN Accuracy for {batch_size}x{iterations} inputs: {accuracy}")
logger.info(f"Torch Accuracy for {batch_size}x{iterations} inputs: {torch_accuracy}")
logger.info(f"Torch vs TTNN Accuracy for {batch_size}x{iterations} inputs: {torch_ttnn_accuracy}")


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"batch_size, iterations",
((8, 5),),
)
def test_demo_imagenet(batch_size, iterations, imagenet_label_dict, model_location_generator, device):
run_swin_imagenet_inference(batch_size, iterations, imagenet_label_dict, model_location_generator, device)
130 changes: 130 additions & 0 deletions models/demos/swin/demo_utils.py
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import glob
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_dataset

from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES


class InputExample(object):
def __init__(self, image, label=None):
self.image = image
self.label = label


def get_input(image_path):
img = Image.open(image_path)
return img


def get_label(image_path):
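# Filenames are expected to end with the ImageNet synset id before the
# extension, e.g. "ILSVRC2012_val_00000001_n01440764.JPEG" (illustrative name);
# the id after the last "_" is looked up in IMAGENET2012_CLASSES.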
_, image_name = image_path.rsplit("/", 1)
image_name_exact, _ = image_name.rsplit(".", 1)
_, label_id = image_name_exact.rsplit("_", 1)
label = list(IMAGENET2012_CLASSES).index(label_id)
return label


preprocess = transforms.Compose(
[
transforms.Resize(256), # Resize the shorter side to 256 pixels
transforms.CenterCrop(224), # Crop the center to 224x224 pixels
transforms.ToTensor(), # Convert the image to a tensor
transforms.Normalize( # Normalize using ImageNet's mean and std
mean=[0.485, 0.456, 0.406], # These are the mean values for each channel
std=[0.229, 0.224, 0.225], # These are the std values for each channel
),
]
)


def get_batch(data_loader):
loaded_images = next(data_loader)
images = None
labels = []
for image in loaded_images:
img = image.image
labels.append(image.label)
if img.mode == "L":
img = img.convert(mode="RGB")

img = preprocess(img)
img = img.to(torch.bfloat16)
img = img.unsqueeze(0)
if images is None:
images = img
else:
images = torch.cat((images, img), dim=0)
return images, labels


def get_data_loader(input_loc, batch_size, iterations):
img_dir = input_loc + "/"
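# "*G" matches image files whose names end in G (e.g. *.JPEG)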
data_path = os.path.join(img_dir, "*G")
files = glob.glob(data_path)

def loader():
examples = []
for f1 in files:
examples.append(
InputExample(
image=get_input(f1),
label=get_label(f1),
)
)
if len(examples) == batch_size:
yield examples
examples = []

def loader_hf():
examples = []
for f1 in files:
examples.append(
InputExample(
image=f1["image"],
label=f1["label"],
)
)
if len(examples) == batch_size:
yield examples
examples = []

if len(files) == 0:
files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True))
files = []
sample_count = batch_size * iterations
for _ in range(sample_count):
files.append(next(files_raw))
del files_raw
return loader_hf()

return loader()


def get_data(input_loc):
img_dir = input_loc + "/"
data_path = os.path.join(img_dir, "*G")
files = sorted(glob.glob(data_path))
examples = []
for f1 in files:
examples.append(
InputExample(
image=get_input(f1),
label=get_label(f1),
)
)
return examples
