# Origin: patch "added a custom loss function for the yolov1_tiny model"
# (models/yolov1_tiny/yolo.py). The class is appended after `make_classifier`,
# whose head ends in nn.Linear(256 * 7 * 7, 1470) followed by nn.Sigmoid(),
# i.e. 1470 = 7 * 7 * 30 sigmoid-activated outputs per image.
#
# Per-cell model output layout (30 values):
#   [x1, y1, w1, h1, c1, x2, y2, w2, h2, c2, p1, ..., p20]
# Per-cell target layout (25 values):
#   [x*, y*, w*, h*, c*, p1*, ..., p20*]
class CustomyoloLoss(nn.Module):
    """Sum-of-squares YOLOv1-style loss for a two-box-per-cell detector.

    For every grid cell whose target confidence equals 1 the loss adds
    an objectness term, a classification term and a box-regression term
    weighted by ``lambda_coord``.  For every other cell it adds only the
    squared predicted confidences, weighted by ``lambda_noobj``.

    The box "responsible" for an object is the one with the higher
    *predicted confidence* (box 2 wins ties).  NOTE: the YOLOv1 paper picks
    the responsible box by IoU with the ground truth instead.

    Args:
        grid_size: Number of cells per image side (default 7).
        num_classes: Number of class scores per cell (default 20).
        lambda_coord: Weight of the box-regression term (default 5).
        lambda_noobj: Weight of the no-object confidence term (default 0.5).
        eps: Lower clamp applied to predicted width/height before the
            square root, so that a prediction of exactly 0 does not produce
            an infinite gradient.
    """

    def __init__(self, grid_size=7, num_classes=20, lambda_coord=5.0,
                 lambda_noobj=0.5, eps=1e-6):
        # The original called super(CustomMSELoss, self), an undefined name,
        # which raised NameError on every instantiation.
        super().__init__()
        self.grid_size = grid_size
        self.num_classes = num_classes
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.eps = eps

    def objectiveness_loss(self, predictions, target):
        """Return the (unweighted) confidence loss, summed over all cells.

        Args:
            predictions: ``[c1, c2]`` for one cell, or ``(..., 2)`` for many.
            target: ``[c*]`` for one cell, or ``(..., 1)`` for many.

        Returns:
            A 0-dim tensor.  A cell with ``c* == 1`` contributes the squared
            error of the more confident box only; any other cell contributes
            ``c1**2 + c2**2``.
        """
        c1 = predictions[..., 0]
        c2 = predictions[..., 1]
        c = target[..., 0]
        # torch.where (not torch.maximum) keeps the original tie rule:
        # on c1 == c2 the second box is used and receives the whole gradient.
        responsible_conf = torch.where(c1 > c2, c1, c2)
        with_object = torch.square(responsible_conf - c)
        without_object = torch.square(predictions).sum(dim=-1)
        return torch.where(c == 1, with_object, without_object).sum()

    def classification_loss(self, predictions, targets):
        """Return the summed squared error between class scores.

        Args:
            predictions: ``[p1, ..., pC]`` (any leading dimensions allowed).
            targets: ``[p1*, ..., pC*]`` with the same shape.
        """
        return torch.sum(torch.square(predictions - targets))

    def box_regression_loss(self, predictions, targets):
        """Return the (unweighted) box loss of the responsible box, summed.

        Args:
            predictions: ``[x1, y1, w1, h1, c1, x2, y2, w2, h2, c2]`` for one
                cell, or ``(..., 10)`` for many.
            targets: ``[x*, y*, w*, h*, ...]``; only the first four entries
                of the last dimension are read.

        Returns:
            A 0-dim tensor: squared centre error plus squared error of the
            square-rooted width and height.
        """
        c1 = predictions[..., 4]
        c2 = predictions[..., 9]
        use_first = (c1 > c2).unsqueeze(-1)
        box = torch.where(use_first, predictions[..., 0:4], predictions[..., 5:9])

        center_err = torch.square(box[..., 0:2] - targets[..., 0:2]).sum()
        # Clamp before sqrt: d/dx sqrt(x) is infinite at x == 0.
        pred_size = torch.sqrt(box[..., 2:4].clamp(min=self.eps))
        size_err = torch.square(pred_size - torch.sqrt(targets[..., 2:4])).sum()
        return center_err + size_err

    def forward(self, predictions, targets):
        """Return the total loss, summed (not averaged) over the batch.

        Args:
            predictions: ``(batch, S * S * (10 + C))``, e.g. ``(50, 1470)``.
            targets: ``(batch, S * S * (5 + C))``, e.g. ``(50, 1225)``.

        Returns:
            A 0-dim tensor (also for an empty batch).
        """
        batch = predictions.shape[0]
        side = self.grid_size
        predictions_ = predictions.reshape((batch, side, side, 10 + self.num_classes))
        targets_ = targets.reshape((batch, side, side, 5 + self.num_classes))

        # Select cells with boolean masks rather than multiplying by a mask,
        # so values in unselected cells can never leak NaN/inf into the sum.
        has_object = targets_[..., 4] == 1
        pred_obj, tgt_obj = predictions_[has_object], targets_[has_object]
        pred_bg, tgt_bg = predictions_[~has_object], targets_[~has_object]

        conf_obj = torch.stack((pred_obj[:, 4], pred_obj[:, 9]), dim=-1)
        conf_bg = torch.stack((pred_bg[:, 4], pred_bg[:, 9]), dim=-1)

        objectness = (
            self.objectiveness_loss(conf_obj, tgt_obj[:, 4:5])
            + self.lambda_noobj * self.objectiveness_loss(conf_bg, tgt_bg[:, 4:5])
        )
        class_loss = self.classification_loss(pred_obj[:, 10:], tgt_obj[:, 5:])
        box_loss = self.lambda_coord * self.box_regression_loss(
            pred_obj[:, :10], tgt_obj[:, :5]
        )

        # Overall loss is the plain sum of the three components.
        return objectness + class_loss + box_loss