#4668: Yolov5 GS Demo Benchmarking
kkeerthana0573 committed Jan 24, 2024
1 parent 7447268 commit 4a83d81
Showing 2 changed files with 259 additions and 52 deletions.
74 changes: 22 additions & 52 deletions models/experimental/yolov5/reference/utils/metrics.py
@@ -53,10 +53,10 @@ def ap_per_class(
# Returns
The average precision as computed in py-faster-rcnn.
"""

conf_numeric = np.array(list(conf), dtype=np.float32)
# Sort by objectness
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
i = np.argsort(conf_numeric)[::-1]
tp, conf, pred_cls = np.array(tp)[i], np.array(conf)[i].astype(float), np.array(pred_cls)[i]

# Find unique classes
unique_classes, nt = np.unique(target_cls, return_counts=True)
@@ -65,6 +65,8 @@ def ap_per_class(
# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
if isinstance(names, tuple):
names = {}
for ci, c in enumerate(unique_classes):
i = pred_cls == c
n_l = nt[ci] # number of labels
@@ -78,9 +80,7 @@ def ap_per_class(

# Recall
recall = tpc / (n_l + eps) # recall curve
r[ci] = np.interp(
-px, -conf[i], recall[:, 0], left=0
) # negative x, xp because xp decreases
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases

# Precision
precision = tpc / (tpc + fpc) # precision curve
@@ -94,21 +94,14 @@ def ap_per_class(

# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + eps)
names = [
v for k, v in names.items() if k in unique_classes
] # list: only classes that have data
names = dict(enumerate(names)) # to dict
names = {v for k, v in names.items() if k in unique_classes}
if isinstance(names, tuple):
names = dict(enumerate(names))
if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / f"{prefix}PR_curve.png", names)
plot_mc_curve(
px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1"
)
plot_mc_curve(
px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision"
)
plot_mc_curve(
px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall"
)
plot_mc_curve(px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1")
plot_mc_curve(px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision")
plot_mc_curve(px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall")

i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
@@ -176,11 +169,7 @@ def process_batch(self, detections, labels):

x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = (
torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1)
.cpu()
.numpy()
)
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
@@ -212,9 +201,7 @@ def tp_fp(self):
def plot(self, normalize=True, save_dir="", names=()):
import seaborn as sn

array = self.matrix / (
(self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1
) # normalize columns
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
@@ -223,9 +210,7 @@ def plot(self, normalize=True, save_dir="", names=()):
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (names + ["background"]) if labels else "auto"
with warnings.catch_warnings():
warnings.simplefilter(
"ignore"
) # suppress empty matrix RuntimeWarning: All-NaN slice encountered
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(
array,
ax=ax,
@@ -275,30 +260,19 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
# IoU
iou = inter / union
if CIoU or DIoU or GIoU:
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(
b2_x1
) # convex (smallest enclosing box) width
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = (
(b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
) / 4 # center dist ** 2
if (
CIoU
): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi**2) * (
torch.atan(w2 / h2) - torch.atan(w1 / h1)
).pow(2)
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
return iou - rho2 / c2 # DIoU
c_area = cw * ch + eps # convex area
return (
iou - (c_area - union) / c_area
) # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou # IoU


@@ -351,9 +325,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
wh1 = wh1[:, None] # [N,1,2]
wh2 = wh2[None] # [1,M,2]
inter = torch.min(wh1, wh2).prod(2) # [N,M]
return inter / (
wh1.prod(2) + wh2.prod(2) - inter + eps
) # iou = inter / (area1 + area2 - inter)
return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps) # iou = inter / (area1 + area2 - inter)


# Plots ----------------------------------------------------------------------------------------------------------------
@@ -366,9 +338,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()):

if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(
px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}"
) # plot(recall, precision)
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)

237 changes: 237 additions & 0 deletions models/experimental/yolov5/tests/test_perf_accuracy_yolov5.py
@@ -0,0 +1,237 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import os
import sys
import torch
import tt_lib
import pytest
import numpy as np

from loguru import logger
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))

from models.perf.perf_utils import prep_perf_report
from models.experimental.yolov5.reference.models.common import DetectMultiBackend
from models.experimental.yolov5.tt.yolov5_detection_model import yolov5s_detection_model
from models.utility_functions import (
torch2tt_tensor,
Profiler,
disable_persistent_kernel_cache,
enable_persistent_kernel_cache,
)
from models.experimental.yolov5.reference.utils.metrics import ap_per_class
from models.experimental.yolov5.reference.utils.general import check_img_size
from models.experimental.yolov5.reference.utils.dataloaders import LoadImages
from models.experimental.yolov5.reference.utils.general import (
non_max_suppression,
scale_boxes,
xyxy2xywh,
)


BATCH_SIZE = 1


def run_perf_yolov5s(
model_location_generator,
expected_inference_time,
expected_compile_time,
iterations,
device,
):
profiler = Profiler()
disable_persistent_kernel_cache()
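# Profiler keys: first_iter = compile + first device inference, second_iter = inference with the
# persistent kernel cache, third_iter = the accuracy (mAP) loop over the dataset, ref_key = CPU reference run.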
first_key = "first_iter"
second_key = "second_iter"
third_key = "third_iter"
cpu_key = "ref_key"
comments = "yolov5s"

reference_model = DetectMultiBackend(
ROOT / "yolov5s.pt",
device=torch.device("cpu"),
dnn=False,
data=None,
fp16=False,
)

reference_module = reference_model.model
tt_module = yolov5s_detection_model(device)

test_input = torch.rand(1, 3, 640, 480)
tt_inputs = torch2tt_tensor(test_input, device)
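# Random 1x3x640x480 input is used only for the timed compile/inference runs below;
# the accuracy loop further down feeds real COCO128 images.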

data_images_path = "/mnt/MLPerf/tt_dnn-models/ssd/coco128/images/train2017"
data_labels_path = "/mnt/MLPerf/tt_dnn-models/ssd/coco128/labels/train2017"
image_files = os.listdir(data_images_path)
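# COCO128 images and YOLO-format labels from the shared MLPerf mount; the evaluation image size
# is rounded by check_img_size to a multiple of the model stride.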

stride = max(int(max(reference_module.stride)), 32)
imgsz = check_img_size((640, 640), s=stride)

with torch.no_grad():
tt_module.eval()
reference_module.eval()

profiler.start(cpu_key)
logits = reference_module(test_input)
tt_lib.device.Synchronize(device)
profiler.end(cpu_key)

profiler.start(first_key)
tt_output = tt_module(tt_inputs)
tt_lib.device.Synchronize(device)
profiler.end(first_key)
del tt_output

enable_persistent_kernel_cache()
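# With the persistent kernel cache populated, the second run measures steady-state device inference time.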

profiler.start(second_key)
tt_output = tt_module(tt_inputs)
tt_lib.device.Synchronize(device)
profiler.end(second_key)
del tt_output

iteration = 0
ap_list = []
all_predictions = []

profiler.start(third_key)
while iteration < iterations:
image_file = image_files[iteration]
image_path = os.path.join(data_images_path, image_file)
dataset = LoadImages(image_path, img_size=imgsz, stride=stride, auto=True)
names = reference_model.names

for path, im, im0s, _, s in dataset:
im = torch.from_numpy(im)
im = im.float()
im /= 255
if len(im.shape) == 3:
im = im[None]

image_file = Path(path).stem
label_file = image_file + ".txt"
label_path = os.path.join(data_labels_path, label_file)

if os.path.exists(label_path):
all_ground_truths = []
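# YOLO label format: one object per line as "class x_center y_center width height", normalized to [0, 1].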
with open(label_path, "r") as f:
    lines = [l.strip().split() for l in f]
reference_labels = [{"class": int(line[0]), "bbox": list(map(float, line[1:]))} for line in lines]
gt_boxes = [label["bbox"] for label in reference_labels]
gt_classes = [label["class"] for label in reference_labels]
all_ground_truths.append(reference_labels)

tt_im = torch2tt_tensor(im, device)
pred = tt_module(tt_im)
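# Post-process raw predictions with standard YOLOv5 NMS settings: confidence 0.25, IoU 0.45, up to 1000 detections.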

conf_thres, iou_thres = 0.25, 0.45
classes = None
agnostic_nms = False

pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=1000)
for i, det in enumerate(pred):
s += "%gx%g " % im.shape[2:]
gn = torch.tensor(im0s.shape)[[1, 0, 1, 0]]
if len(det):
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0s.shape).round()

for c in det[:, 5].unique():
n = (det[:, 5] == c).sum()
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "

for *xyxy, conf, cls in reversed(det):
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
c = int(cls)
prediction = {"class": c, "confidence": f"{conf:.2f}", "bbox": xywh}
all_predictions.append(prediction)

iteration += 1

_, _, _, _, _, ap, _ = ap_per_class(
tp=[pred["bbox"] for pred in all_predictions],
conf=[float(pred["confidence"]) for pred in all_predictions],
pred_cls=[int(pred["class"]) for pred in all_predictions],
target_cls=gt_classes,
)
ap_list.append(ap)

tt_lib.device.Synchronize(device)
profiler.end(third_key)

mAP = np.mean(ap_list)
first_iter_time = profiler.get(first_key)
second_iter_time = profiler.get(second_key)
third_iter_time = profiler.get(third_key)
cpu_time = profiler.get(cpu_key)
compile_time = first_iter_time - second_iter_time
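# Compile overhead is approximated as the first (compile + run) iteration minus the second (cached) iteration.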

prep_perf_report(
model_name="yolov5s",
batch_size=BATCH_SIZE,
inference_and_compile_time=first_iter_time,
inference_time=second_iter_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments=comments,
inference_time_cpu=cpu_time,
)

logger.info(f"yolov5 mAP: {mAP}")
logger.info(f"{comments} inference time: {second_iter_time}")
logger.info(f"yolov5 compile time: {compile_time}")
logger.info(f"yolov5 inference time for {iterations} Samples: {third_iter_time}")


@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"expected_inference_time, expected_compile_time, iterations",
((2.5, 7.8, 5),),
)
def test_perf_bare_metal(
model_location_generator,
expected_inference_time,
expected_compile_time,
iterations,
device,
reset_seeds,
):
run_perf_yolov5s(
model_location_generator,
expected_inference_time,
expected_compile_time,
iterations,
device,
)


@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize(
"expected_inference_time, expected_compile_time, iterations",
((2.3, 0.85, 5),),
)
def test_perf_virtual_machine(
model_location_generator,
expected_inference_time,
expected_compile_time,
iterations,
device,
reset_seeds,
):
run_perf_yolov5s(
model_location_generator,
expected_inference_time,
expected_compile_time,
iterations,
device,
)
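These benchmarks carry the models_performance_bare_metal and models_performance_virtual_machine markers, so they can presumably be selected with pytest's marker filter, e.g. pytest models/experimental/yolov5/tests/test_perf_accuracy_yolov5.py -m models_performance_bare_metal (assuming the markers are registered in the repository's pytest configuration).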
