#4346: Uplift and tilize VGG model
jayasuryamaganuru authored and Sudharsan-V committed Jan 5, 2024
1 parent ce89bbf commit 5d626ec
Showing 6 changed files with 144 additions and 144 deletions.
13 changes: 6 additions & 7 deletions models/experimental/vgg/tests/test_perf_vgg.py
@@ -17,6 +17,9 @@
disable_persistent_kernel_cache,
enable_persistent_kernel_cache,
Profiler,
comp_pcc,
unpad_from_zero,
torch_to_tt_tensor,
)
from models.perf.perf_utils import prep_perf_report

@@ -34,14 +37,10 @@ def run_perf_vgg(imagenet_sample_input, expected_inference_time, expected_compil

image = imagenet_sample_input

tt_image = tt_lib.tensor.Tensor(
image.reshape(-1).tolist(),
get_shape(image.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
tt_image = torch_to_tt_tensor(image, device)

tt_vgg = vgg16(device, disable_conv_on_tt_device=True)
cache_path = "/mnt/MLPerf/tt_dnn-models/tt/VGG/vgg16/"
tt_vgg = vgg16(device, disable_conv_on_tt_device=True, tt_cache_path=cache_path)

torch_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
torch_vgg.eval()
27 changes: 13 additions & 14 deletions models/experimental/vgg/tests/test_vgg11.py
@@ -4,47 +4,46 @@

import torch
import pytest
import tt_lib

from torchvision import models
from loguru import logger

import tt_lib

from models.experimental.vgg.tt.vgg import vgg11
from models.experimental.vgg.vgg_utils import get_shape
from models.utility_functions import comp_pcc
from models.utility_functions import comp_pcc, torch_to_tt_tensor, unpad_from_zero
from models.utility_functions import comp_allclose

_batch_size = 1


@pytest.mark.parametrize(
"dtype",
((tt_lib.tensor.DataType.BFLOAT16),),
)
@pytest.mark.parametrize(
"pcc",
((0.99),),
)
def test_vgg11_inference(device, pcc, imagenet_sample_input):
def test_vgg11_inference(device, pcc, imagenet_sample_input, model_location_generator, dtype):
image = imagenet_sample_input

batch_size = _batch_size
with torch.no_grad():
torch_vgg = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
torch_vgg.eval()

cache_path = "/mnt/MLPerf/tt_dnn-models/tt/VGG/vgg11/"
# TODO: enable conv on tt device after adding fast dtx transform
tt_vgg = vgg11(device, disable_conv_on_tt_device=True)
tt_vgg = vgg11(device, disable_conv_on_tt_device=True, tt_cache_path=cache_path, tt_dtype=dtype)

torch_output = torch_vgg(image).unsqueeze(1).unsqueeze(1)
tt_image = tt_lib.tensor.Tensor(
image.reshape(-1).tolist(),
get_shape(image.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
tt_image = torch_to_tt_tensor(image, device)

tt_output = tt_vgg(tt_image)

tt_output = unpad_from_zero(tt_output, torch_output.shape)
tt_output = tt_output.cpu()
tt_output = torch.Tensor(tt_output.to_torch())

logger.info(comp_allclose(torch_output, tt_output))
pcc_passing, pcc_output = comp_pcc(torch_output, tt_output, pcc)
logger.info(f"Output {pcc_output}")
assert pcc_passing, f"Model output does not meet PCC requirement {pcc}."
28 changes: 13 additions & 15 deletions models/experimental/vgg/tests/test_vgg16.py
@@ -4,25 +4,27 @@

import torch
import pytest
import tt_lib

from torchvision import models
from loguru import logger

import tt_lib
from models.utility_functions import comp_allclose

from models.experimental.vgg.tt.vgg import vgg16
from models.experimental.vgg.vgg_utils import get_shape
from models.utility_functions import comp_pcc

from models.utility_functions import comp_pcc, unpad_from_zero, torch_to_tt_tensor

_batch_size = 1


@pytest.mark.parametrize(
"dtype",
((tt_lib.tensor.DataType.BFLOAT16),),
)
@pytest.mark.parametrize(
"pcc",
((0.99),),
)
def test_vgg16_inference(device, pcc, imagenet_sample_input):
def test_vgg16_inference(device, pcc, imagenet_sample_input, model_location_generator, dtype):
image = imagenet_sample_input

batch_size = _batch_size
@@ -32,20 +34,16 @@ def test_vgg16_inference(device, pcc, imagenet_sample_input):
torch_output = torch_vgg(image).unsqueeze(1).unsqueeze(1)

# TODO: enable conv on tt device after adding fast dtx transform
tt_vgg = vgg16(device, disable_conv_on_tt_device=True)
cache_path = "/mnt/MLPerf/tt_dnn-models/tt/VGG/vgg16/"
tt_vgg = vgg16(device, disable_conv_on_tt_device=True, tt_cache_path=cache_path)

tt_image = tt_lib.tensor.Tensor(
image.reshape(-1).tolist(),
get_shape(image.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
tt_image = torch_to_tt_tensor(image, device)

tt_output = tt_vgg(tt_image)

tt_output = unpad_from_zero(tt_output, torch_output.shape)
tt_output = tt_output.cpu()
tt_output = torch.Tensor(tt_output.to_torch())

logger.info(comp_allclose(torch_output, tt_output))
pcc_passing, pcc_output = comp_pcc(torch_output, tt_output, pcc)
logger.info(f"Output {pcc_output}")
assert pcc_passing, f"Model output does not meet PCC requirement {pcc}."
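
Both tests now point tt_cache_path at pre-converted weight files that vgg.py reads back with tt_lib.tensor.load_tensor. Below is a minimal sketch of how such .bin files could be generated from the torchvision checkpoint. It assumes tt_lib.tensor.dump_tensor exists as the counterpart to load_tensor, reuses torch_to_tt_tensor_rm from models.utility_functions, and is illustrative only — the exact layout/dtype conversion applied before dumping in the real cache is an assumption, and this script is not part of the commit.

import tt_lib
from torchvision import models
from models.utility_functions import torch_to_tt_tensor_rm

cache_path = "/mnt/MLPerf/tt_dnn-models/tt/VGG/vgg16/"  # same path the tests use
tt_dtype = tt_lib.tensor.DataType.BFLOAT16

# Convert every parameter of the torchvision VGG16 checkpoint to a host TT tensor
# (row-major, not placed on device) and dump it with the same file-name pattern
# that TtVGG's load_tensor calls expect, e.g. "classifier.0.weight<dtype>.bin".
state_dict = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).state_dict()
for name, param in state_dict.items():
    tt_param = torch_to_tt_tensor_rm(param, device=None, put_on_device=False)
    tt_lib.tensor.dump_tensor(f"{cache_path}{name}{tt_dtype}.bin", tt_param)
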
113 changes: 40 additions & 73 deletions models/experimental/vgg/tt/vgg.py
@@ -11,14 +11,15 @@
import tt_lib

from tt_lib.fallback_ops import fallback_ops
from models.experimental.vgg.vgg_utils import get_shape
from models.experimental.vgg.tt.vgg_helper_funcs import tt_linear
from models.experimental.vgg.vgg_utils import format_tensor
from models.utility_functions import (
is_conv_supported_on_device,
run_conv_on_device_wrapper,
torch_to_tt_tensor_rm,
torch_to_tt_tensor,
)

from models.helper_funcs import Linear as TtLinear

num_classes = 1000

@@ -31,71 +32,36 @@ def __init__(
init_weights: bool = True,
dropout: float = 0.5,
device=None,
state_dict=None,
base_address="",
tt_cache_path=None,
tt_dtype=tt_lib.tensor.DataType.BFLOAT16,
) -> None:
super().__init__()
assert init_weights == False, "we are loading weights, not initializing them"
self.device = device
self.state_dict = state_dict
self.base_address = base_address

self.features = features
self.avgpool = fallback_ops.AdaptiveAvgPool2d((7, 7))

linear1_weight = state_dict[f"classifier.0.weight"]
linear1_weight = tt_lib.tensor.Tensor(
linear1_weight.reshape(-1).tolist(),
get_shape(linear1_weight.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
self.output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)

linear1_bias = state_dict[f"classifier.0.bias"]
linear1_bias = tt_lib.tensor.Tensor(
linear1_bias.reshape(-1).tolist(),
get_shape(linear1_bias.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)

linear2_weight = state_dict[f"classifier.3.weight"]
linear2_weight = tt_lib.tensor.Tensor(
linear2_weight.reshape(-1).tolist(),
get_shape(linear2_weight.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
linear1_weight = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.0.weight{tt_dtype}.bin")
linear1_bias = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.0.bias{tt_dtype}.bin")

linear2_bias = state_dict[f"classifier.3.bias"]
linear2_bias = tt_lib.tensor.Tensor(
linear2_bias.reshape(-1).tolist(),
get_shape(linear2_bias.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
linear2_weight = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.3.weight{tt_dtype}.bin")
linear2_bias = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.3.bias{tt_dtype}.bin")

linear3_weight = state_dict[f"classifier.6.weight"]
linear3_weight = tt_lib.tensor.Tensor(
linear3_weight.reshape(-1).tolist(),
get_shape(linear3_weight.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
linear3_weight = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.6.weight{tt_dtype}.bin")
linear3_bias = tt_lib.tensor.load_tensor(f"{tt_cache_path}classifier.6.bias{tt_dtype}.bin")

linear3_bias = state_dict[f"classifier.6.bias"]
linear3_bias = tt_lib.tensor.Tensor(
linear3_bias.reshape(-1).tolist(),
get_shape(linear3_bias.shape),
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
linear1 = TtLinear(linear1_weight.shape()[-1], linear1_weight.shape()[-2], linear1_weight, linear1_bias)

linear1 = tt_linear(linear1_weight, linear1_bias, self.device)
linear2 = TtLinear(linear2_weight.shape()[-1], linear2_weight.shape()[-2], linear2_weight, linear2_bias)

linear2 = tt_linear(linear2_weight, linear2_bias, self.device)

linear3 = tt_linear(linear3_weight, linear3_bias, self.device)
linear3 = TtLinear(linear3_weight.shape()[-1], linear3_weight.shape()[-2], linear3_weight, linear3_bias)

self.classifier = [
linear1,
@@ -110,15 +76,13 @@ def forward(self, tt_x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
assert batch_size == 1

for layer in self.features:
if layer is tt_lib.tensor.relu:
tt_x = layer(tt_x)
else:
tt_x = layer(tt_x)
tt_x = layer(tt_x)

batch, c, w, h = tt_x.shape()

tt_x = self.avgpool(tt_x)
tt_x = fallback_ops.reshape(tt_x, batch, 1, 1, c * w * h)
tt_x = format_tensor(tt_x, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
for layer in self.classifier:
tt_x = layer(tt_x)

@@ -132,6 +96,8 @@ def make_layers(
base_address="features",
device=None,
disable_conv_on_tt_device=True,
tt_cache_path=None,
tt_dtype=tt_lib.tensor.DataType.BFLOAT16,
) -> nn.Sequential:
layers: List = []
in_channels = 3
@@ -145,9 +111,7 @@
if not batch_norm:
ind = len(layers)
conv2d_params = [v, in_channels, 3, 3, 1, 1, 1, 1, 1, 1]
if not disable_conv_on_tt_device and is_conv_supported_on_device(
conv2d_params
):
if not disable_conv_on_tt_device and is_conv_supported_on_device(conv2d_params):
assert device is not None
conv2d_weight = state_dict[f"{base_address}.{ind}.weight"]
conv2d_bias = state_dict[f"{base_address}.{ind}.bias"].tolist()
@@ -158,16 +122,8 @@
conv2d_bias,
)
else:
weight = torch_to_tt_tensor_rm(
state_dict[f"{base_address}.{ind}.weight"],
device=device,
put_on_device=False,
)
bias = torch_to_tt_tensor_rm(
state_dict[f"{base_address}.{ind}.bias"],
device=device,
put_on_device=False,
)
weight = tt_lib.tensor.load_tensor(f"{tt_cache_path}{base_address}.{ind}.weight{tt_dtype}.bin")
bias = tt_lib.tensor.load_tensor(f"{tt_cache_path}{base_address}.{ind}.bias{tt_dtype}.bin")
conv2d = fallback_ops.Conv2d(
weights=weight,
biases=bias,
@@ -235,17 +191,20 @@ def make_layers(
}


def _vgg(features, init_weights, device, state_dict, base_address=""):
def _vgg(features, init_weights, device, base_address="", tt_cache_path=None, tt_dtype=tt_lib.tensor.DataType.BFLOAT16):
return TtVGG(
features,
init_weights=init_weights,
device=device,
state_dict=state_dict,
base_address=base_address,
tt_cache_path=tt_cache_path,
tt_dtype=tt_dtype,
)


def vgg16(device, disable_conv_on_tt_device=True) -> TtVGG:
def vgg16(
device, disable_conv_on_tt_device=True, tt_cache_path=None, tt_dtype=tt_lib.tensor.DataType.BFLOAT16
) -> TtVGG:
torch_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
torch_vgg.eval()
state_dict = torch_vgg.state_dict()
@@ -256,15 +215,20 @@ def vgg16(device, disable_conv_on_tt_device=True) -> TtVGG:
state_dict=state_dict,
device=device,
disable_conv_on_tt_device=disable_conv_on_tt_device,
tt_cache_path=tt_cache_path,
tt_dtype=tt_dtype,
),
init_weights=False,
device=device,
state_dict=state_dict,
tt_cache_path=tt_cache_path,
tt_dtype=tt_dtype,
)
return model


def vgg11(device, disable_conv_on_tt_device=True) -> TtVGG:
def vgg11(
device, disable_conv_on_tt_device=True, tt_cache_path=None, tt_dtype=tt_lib.tensor.DataType.BFLOAT16
) -> TtVGG:
torch_vgg = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
torch_vgg.eval()
state_dict = torch_vgg.state_dict()
@@ -275,9 +239,12 @@ def vgg11(device, disable_conv_on_tt_device=True) -> TtVGG:
state_dict=state_dict,
device=device,
disable_conv_on_tt_device=disable_conv_on_tt_device,
tt_cache_path=tt_cache_path,
tt_dtype=tt_dtype,
),
init_weights=False,
device=device,
state_dict=state_dict,
tt_cache_path=tt_cache_path,
tt_dtype=tt_dtype,
)
return model
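
The "tilize" half of the commit is the format_tensor call in forward: the row-major activation coming out of the reshape is padded and converted to TILE layout before it reaches the TtLinear classifier. format_tensor itself lives in models/experimental/vgg/vgg_utils.py and is not shown in this diff; the following is only a rough sketch of what a helper with that signature typically does (not the actual implementation), assuming the tt_lib Tensor methods layout(), cpu(), pad_to_tile(), and to() behave as described in the comments.

def format_tensor(x, target_layout, device, output_mem_config, pad_value=0.0):
    # Nothing to do if the tensor is already in the requested layout.
    if x.layout() == target_layout:
        return x
    # Assumes x is ROW_MAJOR: pull it to host, pad the last two dims up to
    # multiples of 32 (required by TILE layout), convert to TILE, then move it
    # to the device with the requested memory config (DRAM interleaved here).
    x = x.cpu().pad_to_tile(pad_value)
    x = x.to(target_layout)
    return x.to(device, output_mem_config)
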
23 changes: 0 additions & 23 deletions models/experimental/vgg/tt/vgg_helper_funcs.py

This file was deleted.
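
The tt_linear helper this file provided (constructed as tt_linear(weight, bias, device) in the old vgg.py) is superseded by models.helper_funcs.Linear, imported as TtLinear above. A minimal sketch of the replacement pattern, using hypothetical cache-loaded tensors to stand in for the real classifier weights:

import tt_lib
from models.helper_funcs import Linear as TtLinear

cache_path = "/mnt/MLPerf/tt_dnn-models/tt/VGG/vgg16/"  # hypothetical, as above
tt_dtype = tt_lib.tensor.DataType.BFLOAT16

weight = tt_lib.tensor.load_tensor(f"{cache_path}classifier.0.weight{tt_dtype}.bin")
bias = tt_lib.tensor.load_tensor(f"{cache_path}classifier.0.bias{tt_dtype}.bin")

# TtLinear is constructed from (in_features, out_features, weight, bias) and is
# itself callable, so it can sit directly in the self.classifier list.
linear = TtLinear(weight.shape()[-1], weight.shape()[-2], weight, bias)
tt_out = linear(tt_x)  # tt_x: a TILE-layout activation, e.g. the output of format_tensor
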
