Skip to content

Commit

Permalink
Merge pull request supervisely#41 from beyse/fix_iou
Browse files Browse the repository at this point in the history
fix rendering of binary mask for iou calculation
  • Loading branch information
max-supervisely authored Oct 21, 2020
2 parents b9ed1c0 + b264e11 commit d1915e3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions supervisely_lib/metric/iou_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def get_iou(mask_1, mask_2):
def _iou_log_line(iou, intersection, union):
return 'IoU = {:.6f}, mean intersection = {:.6f}, mean union = {:.6f}'.format(iou, intersection, union)

def render_labels_as_binary_mask(labels, class_title, mask):
for label in ann.labels:
if label.obj_class.name == class_title:
label.geometry.draw(mask, True)

class IoUMetric(MetricsBase):

Expand All @@ -41,8 +45,8 @@ def add_pair(self, ann_gt, ann_pred):
img_size = ann_gt.img_size
for cls_gt, cls_pred in self._class_mapping.items():
mask_gt, mask_pred = np.full(img_size, False), np.full(img_size, False)
render_labels_for_class_name(ann_gt.labels, cls_gt, mask_gt)
render_labels_for_class_name(ann_pred.labels, cls_pred, mask_pred)
render_labels_as_binary_mask(ann_gt.labels, cls_gt, mask_gt)
render_labels_as_binary_mask(ann_pred.labels, cls_pred, mask_pred)
class_pair_counters = self._counters[cls_gt]
class_pair_counters[INTERSECTION] += get_intersection(mask_gt, mask_pred)
class_pair_counters[UNION] += get_union(mask_gt, mask_pred)
Expand Down

0 comments on commit d1915e3

Please sign in to comment.