From b264e11f99d3c7f452b28f33f7a24aa85e3e0140 Mon Sep 17 00:00:00 2001 From: Sebastian Beyer Date: Thu, 8 Oct 2020 17:13:57 +0200 Subject: [PATCH] fix rendering of binary mask for iou calculation --- supervisely_lib/metric/iou_metric.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/supervisely_lib/metric/iou_metric.py b/supervisely_lib/metric/iou_metric.py index d48faca04..ff9a32460 100644 --- a/supervisely_lib/metric/iou_metric.py +++ b/supervisely_lib/metric/iou_metric.py @@ -27,6 +27,10 @@ def get_iou(mask_1, mask_2): def _iou_log_line(iou, intersection, union): return 'IoU = {:.6f}, mean intersection = {:.6f}, mean union = {:.6f}'.format(iou, intersection, union) +def render_labels_as_binary_mask(labels, class_title, mask): + for label in ann.labels: + if label.obj_class.name == class_title: + label.geometry.draw(mask, True) class IoUMetric(MetricsBase): @@ -41,8 +45,8 @@ def add_pair(self, ann_gt, ann_pred): img_size = ann_gt.img_size for cls_gt, cls_pred in self._class_mapping.items(): mask_gt, mask_pred = np.full(img_size, False), np.full(img_size, False) - render_labels_for_class_name(ann_gt.labels, cls_gt, mask_gt) - render_labels_for_class_name(ann_pred.labels, cls_pred, mask_pred) + render_labels_as_binary_mask(ann_gt.labels, cls_gt, mask_gt) + render_labels_as_binary_mask(ann_pred.labels, cls_pred, mask_pred) class_pair_counters = self._counters[cls_gt] class_pair_counters[INTERSECTION] += get_intersection(mask_gt, mask_pred) class_pair_counters[UNION] += get_union(mask_gt, mask_pred)