Skip to content

Commit

Permalink
#4622: Yolov3 GS demo Benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeerthana0573 committed Jan 22, 2024
1 parent 7b9aae9 commit 311132f
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 52 deletions.
74 changes: 22 additions & 52 deletions models/experimental/yolov3/reference/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def ap_per_class(
# Returns
The average precision as computed in py-faster-rcnn.
"""

conf_numeric = np.array(list(conf), dtype=np.float32)
# Sort by objectness
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
i = np.argsort(conf_numeric)[::-1]
tp, conf, pred_cls = np.array(tp)[i], np.array(conf)[i].astype(float), np.array(pred_cls)[i]

# Find unique classes
unique_classes, nt = np.unique(target_cls, return_counts=True)
Expand All @@ -67,6 +67,8 @@ def ap_per_class(
# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
if isinstance(names, tuple):
names = {}
for ci, c in enumerate(unique_classes):
i = pred_cls == c
n_l = nt[ci] # number of labels
Expand All @@ -80,9 +82,7 @@ def ap_per_class(

# Recall
recall = tpc / (n_l + eps) # recall curve
r[ci] = np.interp(
-px, -conf[i], recall[:, 0], left=0
) # negative x, xp because xp decreases
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases

# Precision
precision = tpc / (tpc + fpc) # precision curve
Expand All @@ -96,21 +96,14 @@ def ap_per_class(

# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + eps)
names = [
v for k, v in names.items() if k in unique_classes
] # list: only classes that have data
names = dict(enumerate(names)) # to dict
names = {v for k, v in names.items() if k in unique_classes}
if isinstance(names, tuple):
names = dict(enumerate(names))
if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / f"{prefix}PR_curve.png", names)
plot_mc_curve(
px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1"
)
plot_mc_curve(
px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision"
)
plot_mc_curve(
px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall"
)
plot_mc_curve(px, f1, Path(save_dir) / f"{prefix}F1_curve.png", names, ylabel="F1")
plot_mc_curve(px, p, Path(save_dir) / f"{prefix}P_curve.png", names, ylabel="Precision")
plot_mc_curve(px, r, Path(save_dir) / f"{prefix}R_curve.png", names, ylabel="Recall")

i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
Expand Down Expand Up @@ -178,11 +171,7 @@ def process_batch(self, detections, labels):

x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = (
torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1)
.cpu()
.numpy()
)
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
Expand Down Expand Up @@ -215,9 +204,7 @@ def tp_fp(self):
def plot(self, normalize=True, save_dir="", names=()):
import seaborn as sn

array = self.matrix / (
(self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1
) # normalize columns
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
Expand All @@ -226,9 +213,7 @@ def plot(self, normalize=True, save_dir="", names=()):
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (names + ["background"]) if labels else "auto"
with warnings.catch_warnings():
warnings.simplefilter(
"ignore"
) # suppress empty matrix RuntimeWarning: All-NaN slice encountered
warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(
array,
ax=ax,
Expand Down Expand Up @@ -278,30 +263,19 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
# IoU
iou = inter / union
if CIoU or DIoU or GIoU:
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(
b2_x1
) # convex (smallest enclosing box) width
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = (
(b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
) / 4 # center dist ** 2
if (
CIoU
): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi**2) * (
torch.atan(w2 / h2) - torch.atan(w1 / h1)
).pow(2)
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
return iou - rho2 / c2 # DIoU
c_area = cw * ch + eps # convex area
return (
iou - (c_area - union) / c_area
) # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou # IoU


Expand Down Expand Up @@ -354,9 +328,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
wh1 = wh1[:, None] # [N,1,2]
wh2 = wh2[None] # [1,M,2]
inter = torch.min(wh1, wh2).prod(2) # [N,M]
return inter / (
wh1.prod(2) + wh2.prod(2) - inter + eps
) # iou = inter / (area1 + area2 - inter)
return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps) # iou = inter / (area1 + area2 - inter)


# Plots ----------------------------------------------------------------------------------------------------------------
Expand All @@ -370,9 +342,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()):

if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(
px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}"
) # plot(recall, precision)
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)

Expand Down
Loading

0 comments on commit 311132f

Please sign in to comment.