-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#13403: Add ttnn support for swin model
- Loading branch information
1 parent
d46ba83
commit 1536046
Showing
9 changed files
with
1,536 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
## Swin Model | ||
|
||
# Platforms: | ||
GS E150, WH N150, WH N300 | ||
|
||
## Introduction | ||
The Swin Transformer is a variant of the Vision Transformer that generates hierarchical feature maps by progressively merging image patches in its deeper layers. It achieves linear computational complexity relative to the input image size by restricting self-attention calculations to local windows. This design allows it to function as a versatile backbone for tasks like image classification and dense recognition. In contrast, earlier Vision Transformers generate feature maps at a single low resolution and have quadratic computational complexity, as they compute self-attention across the entire image. | ||
|
||
# Details | ||
The entry point to swin model is swin_for_image_classification in `models/demos/swin/tt/tt/ttnn_optimized_swin.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `microsoft/swin-tiny-patch4-window7-224` version from huggingface as our reference. | ||
|
||
|
||
## Batch size: 8 | ||
|
||
Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the `batch_size` to 8 | ||
|
||
Use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-5-device_params0]` to run the ttnn_optimized_swin demo. | ||
|
||
|
||
If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/swin/demo/demo.py::test_demo_imagenet[8-<n_iterations>-device_params0]` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import torch | ||
from loguru import logger | ||
import pytest | ||
|
||
|
||
from models.utility_functions import ( | ||
disable_compilation_reports, | ||
disable_persistent_kernel_cache, | ||
enable_persistent_kernel_cache, | ||
profiler, | ||
) | ||
import ttnn | ||
|
||
from models.demos.swin.demo_utils import get_data_loader, get_batch, preprocess | ||
from loguru import logger | ||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
from models.demos.swin.tt import ttnn_optimized_swin | ||
from transformers import SwinForImageClassification as HF_SwinForImageClassification | ||
from models.demos.swin.tt.swin_utils import get_relative_position, get_attn_mask | ||
|
||
|
||
def run_swin_imagenet_inference( | ||
batch_size, | ||
iterations, | ||
imagenet_label_dict, | ||
model_location_generator, | ||
device, | ||
): | ||
disable_persistent_kernel_cache() | ||
disable_compilation_reports() | ||
profiler.clear() | ||
|
||
# Setup model | ||
torch_model = HF_SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224") | ||
config = torch_model.config | ||
torch_model.to(torch.bfloat16) | ||
torch_model.eval() | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: torch_model, | ||
model_name=torch_model, | ||
device=device, | ||
convert_to_ttnn=lambda *_: True, | ||
custom_preprocessor=ttnn_optimized_swin.custom_preprocessor, | ||
) | ||
|
||
# load inputs | ||
logger.info("ImageNet-1k validation Dataset") | ||
input_loc = str(model_location_generator("ImageNet_data")) | ||
data_loader = get_data_loader(input_loc, batch_size, iterations) | ||
|
||
bias_table = get_relative_position(torch_model.config, parameters.swin, device) | ||
attention_mask_list = get_attn_mask(torch_model.config, device) | ||
|
||
# load ImageNet batch by batch | ||
# and run inference | ||
correct = 0 | ||
torch_ttnn_correct = 0 | ||
torch_correct = 0 | ||
for iter in range(iterations): | ||
predictions = [] | ||
torch_predictions = [] | ||
inputs, labels = get_batch(data_loader) | ||
torch_outputs = torch_model(inputs) | ||
tt_batched_input_tensor = ttnn.from_torch(inputs, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) | ||
tt_output = ttnn_optimized_swin.swin_for_image_classification( | ||
torch_model.config, | ||
tt_batched_input_tensor, | ||
parameters=parameters, | ||
device=device, | ||
bias_table=bias_table, | ||
attention_mask_list=attention_mask_list, | ||
) | ||
tt_output = ttnn.to_torch(tt_output) | ||
prediction = tt_output.argmax(dim=-1) | ||
torch_prediction = torch_outputs[0].argmax(dim=-1) | ||
for i in range(batch_size): | ||
predictions.append(imagenet_label_dict[prediction[i].item()]) | ||
torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()]) | ||
logger.info( | ||
f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted label:{predictions[-1]} \tPredicted Label: {predictions[-1]}" | ||
) | ||
if imagenet_label_dict[labels[i]] == predictions[-1]: | ||
correct += 1 | ||
if imagenet_label_dict[labels[i]] == torch_predictions[-1]: | ||
torch_correct += 1 | ||
if predictions[-1] == torch_predictions[-1]: | ||
torch_ttnn_correct += 1 | ||
|
||
del tt_output, tt_batched_input_tensor, inputs, labels, predictions | ||
accuracy = correct / (batch_size * iterations) | ||
torch_accuracy = torch_correct / (batch_size * iterations) | ||
torch_ttnn_accuracy = torch_ttnn_correct / (batch_size * iterations) | ||
|
||
logger.info(f"Model Swin for Image Classification") | ||
logger.info(f"TTNN Accuracy for {batch_size}x{iterations} inputs: {accuracy}") | ||
logger.info(f"Torch Accuracy for {batch_size}x{iterations} inputs: {torch_accuracy}") | ||
logger.info(f"Torch vs TTNN Accuracy for {batch_size}x{iterations} inputs: {torch_ttnn_accuracy}") | ||
|
||
|
||
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) | ||
@pytest.mark.parametrize( | ||
"batch_size, iterations", | ||
((8, 5),), | ||
) | ||
def test_demo_imagenet(batch_size, iterations, imagenet_label_dict, model_location_generator, device): | ||
run_swin_imagenet_inference(batch_size, iterations, imagenet_label_dict, model_location_generator, device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from PIL import Image | ||
import torch | ||
import os | ||
import glob | ||
from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES | ||
from datasets import load_dataset | ||
from torchvision import models | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
import torch | ||
|
||
|
||
class InputExample(object): | ||
def __init__(self, image, label=None): | ||
self.image = image | ||
self.label = label | ||
|
||
|
||
def get_input(image_path): | ||
img = Image.open(image_path) | ||
return img | ||
|
||
|
||
def get_label(image_path): | ||
_, image_name = image_path.rsplit("/", 1) | ||
image_name_exact, _ = image_name.rsplit(".", 1) | ||
_, label_id = image_name_exact.rsplit("_", 1) | ||
label = list(IMAGENET2012_CLASSES).index(label_id) | ||
return label | ||
|
||
|
||
preprocess = transforms.Compose( | ||
[ | ||
transforms.Resize(256), # Resize the shorter side to 256 pixels | ||
transforms.CenterCrop(224), # Crop the center to 224x224 pixels | ||
transforms.ToTensor(), # Convert the image to a tensor | ||
transforms.Normalize( # Normalize using ImageNet's mean and std | ||
mean=[0.485, 0.456, 0.406], # These are the mean values for each channel | ||
std=[0.229, 0.224, 0.225], # These are the std values for each channel | ||
), | ||
] | ||
) | ||
|
||
|
||
def get_batch(data_loader): | ||
loaded_images = next(data_loader) | ||
images = None | ||
labels = [] | ||
transform = transforms.ToTensor() | ||
resize_transform = transforms.Resize((224, 224)) | ||
for image in loaded_images: | ||
img = image.image | ||
labels.append(image.label) | ||
if img.mode == "L": | ||
img = img.convert(mode="RGB") | ||
|
||
img = preprocess(img) | ||
img = img.to(torch.bfloat16) | ||
img = img.unsqueeze(0) | ||
if images is None: | ||
images = img | ||
else: | ||
images = torch.cat((images, img), dim=0) | ||
return images, labels | ||
|
||
|
||
def get_data_loader(input_loc, batch_size, iterations): | ||
img_dir = input_loc + "/" | ||
data_path = os.path.join(img_dir, "*G") | ||
files = glob.glob(data_path) | ||
|
||
def loader(): | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=get_input(f1), | ||
label=get_label(f1), | ||
) | ||
) | ||
if len(examples) == batch_size: | ||
yield examples | ||
del examples | ||
examples = [] | ||
|
||
def loader_hf(): | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=f1["image"], | ||
label=f1["label"], | ||
) | ||
) | ||
if len(examples) == batch_size: | ||
yield examples | ||
del examples | ||
examples = [] | ||
|
||
if len(files) == 0: | ||
files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True)) | ||
files = [] | ||
sample_count = batch_size * iterations | ||
for _ in range(sample_count): | ||
files.append(next(files_raw)) | ||
del files_raw | ||
return loader_hf() | ||
|
||
return loader() | ||
|
||
|
||
def get_data(input_loc): | ||
img_dir = input_loc + "/" | ||
data_path = os.path.join(img_dir, "*G") | ||
files = sorted(glob.glob(data_path)) | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=get_input(f1), | ||
label=get_label(f1), | ||
) | ||
) | ||
image_examples = examples | ||
|
||
return image_examples |
Oops, something went wrong.