From 6e94c4a87c1249f830543f2f01943d89c6fb234c Mon Sep 17 00:00:00 2001 From: kkeerthana0573 Date: Fri, 8 Dec 2023 10:02:46 +0000 Subject: [PATCH] #4622: Yolov3 GS demo Benchmarking --- .../yolov3/reference/utils/metrics.py | 74 ++--- .../yolov3/tests/test_perf_accuracy_yolov3.py | 261 ++++++++++++++++++ 2 files changed, 283 insertions(+), 52 deletions(-) create mode 100644 models/experimental/yolov3/tests/test_perf_accuracy_yolov3.py diff --git a/models/experimental/yolov3/reference/utils/metrics.py b/models/experimental/yolov3/reference/utils/metrics.py index db99caa082a..75fe956de57 100644 --- a/models/experimental/yolov3/reference/utils/metrics.py +++ b/models/experimental/yolov3/reference/utils/metrics.py @@ -55,10 +55,10 @@ def ap_per_class( # Returns The average precision as computed in py-faster-rcnn. """ - + conf_numeric = np.array(list(conf), dtype=np.float32) # Sort by objectness - i = np.argsort(-conf) - tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + i = np.argsort(conf_numeric)[::-1] + tp, conf, pred_cls = np.array(tp)[i], np.array(conf)[i].astype(float), np.array(pred_cls)[i] # Find unique classes unique_classes, nt = np.unique(target_cls, return_counts=True) @@ -67,6 +67,8 @@ def ap_per_class( # Create Precision-Recall curve and compute AP for each class px, py = np.linspace(0, 1, 1000), [] # for plotting ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + if isinstance(names, tuple): + names = {} for ci, c in enumerate(unique_classes): i = pred_cls == c n_l = nt[ci] # number of labels @@ -80,9 +82,7 @@ def ap_per_class( # Recall recall = tpc / (n_l + eps) # recall curve - r[ci] = np.interp( - -px, -conf[i], recall[:, 0], left=0 - ) # negative x, xp because xp decreases + r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases # Precision precision = tpc / (tpc + fpc) # precision curve @@ -96,21 +96,14 @@ def ap_per_class( # Compute F1 (harmonic mean of precision and recall) f1 = 2 * p * r / (p + r + eps) - names = [ - v for k, v in names.items() if k in unique_classes - ] # list: only classes that have data - names = dict(enumerate(names)) # to dict + names = {v for k, v in names.items() if k in unique_classes} + if isinstance(names, tuple): + names = dict(enumerate(names)) if plot: plot_pr_curve(px, py, ap, Path(save_dir) / f"{prefix}PR_curve.png", names) - plot_mc_curve( - px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1" - ) - plot_mc_curve( - px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision" - ) - plot_mc_curve( - px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall" - ) + plot_mc_curve(px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1") + plot_mc_curve(px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision") + plot_mc_curve(px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall") i = smooth(f1.mean(0), 0.1).argmax() # max F1 index p, r, f1 = p[:, i], r[:, i], f1[:, i] @@ -178,11 +171,7 @@ def process_batch(self, detections, labels): x = torch.where(iou > self.iou_thres) if x[0].shape[0]: - matches = ( - torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1) - .cpu() - .numpy() - ) + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() if x[0].shape[0] > 1: matches = matches[matches[:, 2].argsort()[::-1]] matches = matches[np.unique(matches[:, 1], return_index=True)[1]] @@ -215,9 +204,7 @@ def tp_fp(self): def plot(self, normalize=True, save_dir="", names=()): import seaborn as sn - array = self.matrix / ( - (self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1 - ) # normalize columns + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) @@ -226,9 +213,7 @@ def plot(self, normalize=True, save_dir="", names=()): labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels ticklabels = (names + ["background"]) if labels else "auto" with warnings.catch_warnings(): - warnings.simplefilter( - "ignore" - ) # suppress empty matrix RuntimeWarning: All-NaN slice encountered + warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered sn.heatmap( array, ax=ax, @@ -278,30 +263,19 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 # IoU iou = inter / union if CIoU or DIoU or GIoU: - cw = b1_x2.maximum(b2_x2) - b1_x1.minimum( - b2_x1 - ) # convex (smallest enclosing box) width + cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 c2 = cw**2 + ch**2 + eps # convex diagonal squared - rho2 = ( - (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 - + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 - ) / 4 # center dist ** 2 - if ( - CIoU - ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - v = (4 / math.pi**2) * ( - torch.atan(w2 / h2) - torch.atan(w1 / h1) - ).pow(2) + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) with torch.no_grad(): alpha = v / (v - iou + (1 + eps)) return iou - (rho2 / c2 + v * alpha) # CIoU return iou - rho2 / c2 # DIoU c_area = cw * ch + eps # convex area - return ( - iou - (c_area - union) / c_area - ) # GIoU https://arxiv.org/pdf/1902.09630.pdf + return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf return iou # IoU @@ -354,9 +328,7 @@ def wh_iou(wh1, wh2, eps=1e-7): wh1 = wh1[:, None] # [N,1,2] wh2 = wh2[None] # [1,M,2] inter = torch.min(wh1, wh2).prod(2) # [N,M] - return inter / ( - wh1.prod(2) + wh2.prod(2) - inter + eps - ) # iou = inter / (area1 + area2 - inter) + return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps) # iou = inter / (area1 + area2 - inter) # Plots ---------------------------------------------------------------------------------------------------------------- @@ -370,9 +342,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()): if 0 < len(names) < 21: # display per-class legend if < 21 classes for i, y in enumerate(py.T): - ax.plot( - px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}" - ) # plot(recall, precision) + ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) else: ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) diff --git a/models/experimental/yolov3/tests/test_perf_accuracy_yolov3.py b/models/experimental/yolov3/tests/test_perf_accuracy_yolov3.py new file mode 100644 index 00000000000..9f954629d4b --- /dev/null +++ b/models/experimental/yolov3/tests/test_perf_accuracy_yolov3.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os +import torch +import tt_lib +import pytest +import numpy as np + +from pathlib import Path +from loguru import logger +from collections import defaultdict + +from models.perf.perf_utils import prep_perf_report +from models.experimental.yolov3.reference.models.common import DetectMultiBackend +from models.experimental.yolov3.tt.yolov3_detection_model import TtDetectionModel +from models.experimental.yolov3.reference.utils.dataloaders import LoadImages +from models.utility_functions import ( + torch2tt_tensor, + Profiler, + disable_persistent_kernel_cache, + enable_persistent_kernel_cache, +) +from models.experimental.yolov3.reference.utils.general import ( + check_img_size, + non_max_suppression, + scale_boxes, + xyxy2xywh, + check_img_size, +) +from models.experimental.yolov3.reference.utils.metrics import * + + +BATCH_SIZE = 1 + + +def run_perf_yolov3(expected_inference_time, expected_compile_time, model_location_generator, device, iterations): + profiler = Profiler() + disable_persistent_kernel_cache() + first_key = "first_iter" + second_key = "second_iter" + third_key = "third_iter" + cpu_key = "ref_key" + comments = "yolov3-fused" + + model_path = model_location_generator("models", model_subdir="Yolo") + data_path = model_location_generator("data", model_subdir="Yolo") + + data_image_path = str(data_path / "images") + data_coco = str(data_path / "coco128.yaml") + model_config_path = str(data_path / "yolov3.yaml") + weights_loc = str(model_path / "yolov3.pt") + + reference_model = DetectMultiBackend(weights_loc, device=torch.device("cpu"), dnn=False, data=data_coco, fp16=False) + state_dict = reference_model.state_dict() + reference_model = reference_model.model + reference_model.eval() + + tt_module = TtDetectionModel( + cfg=model_config_path, + state_dict=state_dict, + base_address="model.model", + device=device, + ) + tt_module.eval() + stride = max(int(max(reference_model.stride)), 32) + imgsz = check_img_size((640, 640), s=stride) + dataset = LoadImages(data_image_path, img_size=imgsz, stride=stride, auto=True) + + path, im, _, _, _ = next(iter(dataset)) + im = torch.from_numpy(im) + im = im.float() + im /= 255 + if len(im.shape) == 3: + im = im[None] + + tt_im = torch2tt_tensor(im, device, tt_layout=tt_lib.tensor.Layout.ROW_MAJOR) + + with torch.no_grad(): + profiler.start(cpu_key) + pt_out = reference_model(im) + tt_lib.device.Synchronize(device) + profiler.end(cpu_key) + + profiler.start(first_key) + tt_out = tt_module(tt_im) + tt_lib.device.Synchronize(device) + profiler.end(first_key) + del tt_out + + enable_persistent_kernel_cache() + + profiler.start(second_key) + tt_out = tt_module(tt_im) + tt_lib.device.Synchronize(device) + profiler.end(second_key) + del tt_out + + data_images_path = "/mnt/MLPerf/tt_dnn-models/ssd/coco128/images/train2017" + data_labels_path = "/mnt/MLPerf/tt_dnn-models/ssd/coco128/labels/train2017" + image_files = os.listdir(data_images_path) + iteration = 0 + ap_list = [] + all_predictions = [] + + profiler.start(third_key) + while iteration < iterations: + image_file = image_files[iteration] + image_path = os.path.join(data_images_path, image_file) + dataset = LoadImages(image_path, img_size=imgsz, stride=stride, auto=True) + names = reference_model.names + + path, im, im0s, _, s = next(iter(dataset)) + im = torch.from_numpy(im) + im = im.float() + im /= 255 + if len(im.shape) == 3: + im = im[None] + + image_file = Path(path).stem + label_file = image_file + ".txt" + label_path = os.path.join(data_labels_path, label_file) + + if os.path.exists(label_path): + all_ground_truths = [] + lines = [l.strip().split() for l in open(label_path, "r").readlines()] + reference_labels = [{"class": int(line[0]), "bbox": list(map(float, line[1:]))} for line in lines] + gt_boxes = [label["bbox"] for label in reference_labels] + gt_classes = [label["class"] for label in reference_labels] + all_ground_truths.append(reference_labels) + + tt_im = torch2tt_tensor(im, device) + pred = tt_module(tt_im) + + conf_thres, iou_thres = 0.25, 0.45 + classes = None + agnostic_nms = False + + pred = non_max_suppression( + prediction=pred, + conf_thres=conf_thres, + iou_thres=iou_thres, + classes=classes, + agnostic=agnostic_nms, + max_det=1000, + ) + + for i, det in enumerate(pred): + s += "%gx%g " % im.shape[2:] + gn = torch.tensor(im0s.shape)[[1, 0, 1, 0]] + + if len(det): + det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0s.shape).round() + class_confidence = defaultdict(float) + class_bbox = {} + + for bbox in det: + label = int(bbox[5]) + confidence = float(bbox[4]) + bbox_info = { + "x_center": float(bbox[0]), + "y_center": float(bbox[1]), + "width": float(bbox[2] - bbox[0]), + "height": float(bbox[3] - bbox[1]), + } + + if confidence > class_confidence[label]: + class_confidence[label] = confidence + class_bbox[label] = bbox_info + + for *xyxy, conf, cls in reversed(det): + if True: + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + + c = int(cls) + label = None if False else (f"{names[c]} {conf:.2f}") + prediction = {"class": c, "confidence": f"{conf:.2f}", "bbox": xywh} + all_predictions.append(prediction) + + iteration += 1 + + _, _, _, _, _, ap, _ = ap_per_class( + tp=[pred["bbox"] for pred in all_predictions], + conf=[float(pred["confidence"]) for pred in all_predictions], + pred_cls=[int(pred["class"]) for pred in all_predictions], + target_cls=gt_classes, + ) + ap_list.append(ap) + tt_lib.device.Synchronize(device) + profiler.end(third_key) + + mAP = np.mean(ap_list) + first_iter_time = profiler.get(first_key) + second_iter_time = profiler.get(second_key) + third_iter_time = profiler.get(third_key) + cpu_time = profiler.get(cpu_key) + + prep_perf_report( + model_name="yolov3", + batch_size=BATCH_SIZE, + inference_and_compile_time=first_iter_time, + inference_time=second_iter_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments=comments, + inference_time_cpu=cpu_time, + ) + + compile_time = first_iter_time - second_iter_time + + logger.info(f"yolov3 mAP: {mAP}") + logger.info(f"yolov3 {comments} inference time: {second_iter_time}") + logger.info(f"yolov3 compile time: {compile_time}") + logger.info(f"yolov3 inference time for {iterations} Samples: {third_iter_time}") + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "expected_inference_time, expected_compile_time, iterations", + ( + ( + 5.86, + 9.47, + 10, + ), + ), +) +def test_perf_bare_metal( + use_program_cache, + expected_inference_time, + expected_compile_time, + model_location_generator, + device, + iterations, + reset_seeds, +): + run_perf_yolov3(expected_inference_time, expected_compile_time, model_location_generator, device, iterations) + + +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize( + "expected_inference_time, expected_compile_time, iterations", + ( + ( + 5.8, + 0.7, + 10, + ), + ), +) +def test_perf_virtual_machine( + use_program_cache, + expected_inference_time, + expected_compile_time, + model_location_generator, + device, + iterations, + reset_seeds, +): + run_perf_yolov3(expected_inference_time, expected_compile_time, model_location_generator, device, iterations)