diff --git a/class_config/colored_robots.yaml b/class_config/colored_robots.yaml new file mode 100644 index 0000000..8ba9b6a --- /dev/null +++ b/class_config/colored_robots.yaml @@ -0,0 +1,6 @@ +group_classes: + - robot_red + - robot_blue + - robot_unknown + +surrogate_class: robot diff --git a/class_config/default.yaml b/class_config/default.yaml new file mode 100644 index 0000000..83f72fd --- /dev/null +++ b/class_config/default.yaml @@ -0,0 +1,2 @@ +group_classes: +surrogate_class: "" diff --git a/yoeo/detect.py b/yoeo/detect.py index 6e95cad..f6f043a 100755 --- a/yoeo/detect.py +++ b/yoeo/detect.py @@ -12,12 +12,14 @@ from torch.utils.data import DataLoader from torch.autograd import Variable -from typing import Optional, List +from typing import Optional from imgaug.augmentables.segmaps import SegmentationMapsOnImage from yoeo.models import load_model -from yoeo.utils.utils import load_classes, rescale_boxes, non_max_suppression, print_environment_info, rescale_segmentation +from yoeo.utils.class_config import ClassConfig +from yoeo.utils.dataclasses import ClassNames, GroupConfig +from yoeo.utils.utils import rescale_boxes, non_max_suppression, print_environment_info, rescale_segmentation from yoeo.utils.datasets import ImageFolder from yoeo.utils.transforms import Resize, DEFAULT_TRANSFORMS @@ -26,9 +28,9 @@ from matplotlib.ticker import NullLocator -def detect_directory(model_path, weights_path, img_path, classes, output_path, +def detect_directory(model_path, weights_path, img_path, class_config: ClassConfig, output_path, batch_size=8, img_size=416, n_cpu=8, conf_thres=0.5, nms_thres=0.5, - robot_class_ids: Optional[List[int]] = None): + ): """Detects objects on all images in specified directory and saves output images with drawn detections. :param model_path: Path to model definition file (.cfg) @@ -37,8 +39,8 @@ def detect_directory(model_path, weights_path, img_path, classes, output_path, :type weights_path: str :param img_path: Path to directory with images to inference :type img_path: str - :param classes: List of class names - :type classes: [str] + :param class_config: Class configuration + :type class_config: ClassConfig :param output_path: Path to output directory :type output_path: str :param batch_size: Size of each image batch, defaults to 8 @@ -51,8 +53,6 @@ def detect_directory(model_path, weights_path, img_path, classes, output_path, :type conf_thres: float, optional :param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5 :type nms_thres: float, optional - :param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist. 
- :type robot_class_ids: List[int], optional """ dataloader = _create_data_loader(img_path, batch_size, img_size, n_cpu) model = load_model(model_path, weights_path) @@ -63,29 +63,36 @@ def detect_directory(model_path, weights_path, img_path, classes, output_path, output_path, conf_thres, nms_thres, - robot_class_ids=robot_class_ids + class_config.get_group_config() ) _draw_and_save_output_images( - img_detections, segmentations, imgs, img_size, output_path, classes) + img_detections, segmentations, imgs, img_size, output_path, class_config.get_ungrouped_det_class_names()) print(f"---- Detections were saved to: '{output_path}' ----") -def detect_image(model, image, img_size=416, conf_thres=0.5, nms_thres=0.5, robot_class_ids: Optional[List[int]] = None): +def detect_image(model, + image: np.ndarray, + img_size: int = 416, + conf_thres: float = 0.5, + nms_thres: float = 0.5, + group_config: Optional[GroupConfig] = None + ): """Inferences one image with model. :param model: Model for inference :type model: models.Darknet :param image: Image to inference - :type image: nd.array + :type image: np.ndarray :param img_size: Size of each image dimension for yolo, defaults to 416 - :type img_size: int, optional + :type img_size: int :param conf_thres: Object confidence threshold, defaults to 0.5 - :type conf_thres: float, optional + :type conf_thres: float :param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5 - :type nms_thres: float, optional - :param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist. - :type robot_class_ids: List[int], optional + :type nms_thres: float + :param group_config: GroupConfiguration for this model (optional, defaults to None) + :type group_config: Optional[GroupConfig] + :return: Detections on image with each detection in the format: [x1, y1, x2, y2, confidence, class], Segmentation as 2d numpy array with the coresponding class id in each cell :rtype: nd.array, nd.array """ @@ -105,13 +112,24 @@ def detect_image(model, image, img_size=416, conf_thres=0.5, nms_thres=0.5, robo # Get detections with torch.no_grad(): detections, segmentations = model(input_img) - detections = non_max_suppression(detections, conf_thres, nms_thres, robot_class_ids=robot_class_ids) + detections = non_max_suppression( + prediction=detections, + conf_thres=conf_thres, + iou_thres=nms_thres, + group_config=group_config + ) detections = rescale_boxes(detections[0], img_size, image.shape[0:2]) segmentations = rescale_segmentation(segmentations, image.shape[0:2]) return detections.numpy(), segmentations.cpu().detach().numpy() -def detect(model, dataloader, output_path, conf_thres, nms_thres, robot_class_ids: Optional[List[int]] = None): +def detect(model, + dataloader: DataLoader, + output_path: str, + conf_thres: float = 0.5, + nms_thres: float = 0.5, + group_config: Optional[GroupConfig] = None + ): """Inferences images with model. :param model: Model for inference @@ -121,11 +139,12 @@ def detect(model, dataloader, output_path, conf_thres, nms_thres, robot_class_id :param output_path: Path to output directory :type output_path: str :param conf_thres: Object confidence threshold, defaults to 0.5 - :type conf_thres: float, optional + :type conf_thres: float :param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5 - :type nms_thres: float, optional - :param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist. 
- :type robot_class_ids: List[int], optional + :type nms_thres: float + :param group_config: GroupConfig for this model (optional, defaults to None) + :type group_config: Optional[GroupConfig] + :return: List of detections. The coordinates are given for the padded image that is provided by the dataloader. Use `utils.rescale_boxes` to transform them into the desired input image coordinate system before its transformed by the dataloader), List of input image paths @@ -149,7 +168,12 @@ def detect(model, dataloader, output_path, conf_thres, nms_thres, robot_class_id # Get detections with torch.no_grad(): detections, segmentations = model(input_imgs) - detections = non_max_suppression(detections, conf_thres, nms_thres, robot_class_ids=robot_class_ids) + detections = non_max_suppression( + prediction=detections, + conf_thres=conf_thres, + iou_thres=nms_thres, + group_config=group_config + ) # Store image and detections img_detections.extend(detections) @@ -310,33 +334,24 @@ def run(): parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation") parser.add_argument("--conf_thres", type=float, default=0.5, help="Object confidence threshold") parser.add_argument("--nms_thres", type=float, default=0.4, help="IOU threshold for non-maximum suppression") - parser.add_argument("--multiple_robot_classes", action="store_true", - help="If multiple robot classes exist and nms shall be performed across all robot classes") + parser.add_argument("--class_config", type=str, default="class_config/default.yaml", help="Class configuration for evaluation") args = parser.parse_args() print(f"Command line arguments: {args}") - # Extract class names from file - classes = load_classes(args.classes)['detection'] # List of class names - - robot_class_ids = None - if args.multiple_robot_classes: - robot_class_ids = [] - for idx, c in enumerate(classes): - if "robot" in c: - robot_class_ids.append(idx) + class_names = ClassNames.load_from(args.classes) + class_config = ClassConfig.load_from(args.class_config, class_names) detect_directory( args.model, args.weights, args.images, - classes, + class_config, args.output, batch_size=args.batch_size, img_size=args.img_size, n_cpu=args.n_cpu, conf_thres=args.conf_thres, nms_thres=args.nms_thres, - robot_class_ids=robot_class_ids ) diff --git a/yoeo/test.py b/yoeo/test.py index d266572..224338d 100755 --- a/yoeo/test.py +++ b/yoeo/test.py @@ -1,7 +1,7 @@ #! 
/usr/bin/env python3 from __future__ import division, annotations -from typing import List, Optional +from typing import List, Optional, Tuple import argparse import tqdm @@ -14,16 +14,18 @@ from torch.autograd import Variable from yoeo.models import load_model -from yoeo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, \ +from yoeo.utils.utils import ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, \ print_environment_info, seg_iou from yoeo.utils.datasets import ListDataset from yoeo.utils.transforms import DEFAULT_TRANSFORMS +from yoeo.utils.dataclasses import ClassNames +from yoeo.utils.class_config import ClassConfig from yoeo.utils.parse_config import parse_data_config +from yoeo.utils.metric import Metric -def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_size=8, img_size=416, - n_cpu=8, iou_thres=0.5, conf_thres=0.5, nms_thres=0.5, verbose=True, - robot_class_ids: Optional[List[int]] = None): +def evaluate_model_file(model_path, weights_path, img_path, class_config, batch_size=8, img_size=416, + n_cpu=8, iou_thres=0.5, conf_thres=0.5, nms_thres=0.5, verbose=True): """Evaluate model on validation dataset. :param model_path: Path to model definition file (.cfg) @@ -32,8 +34,8 @@ def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_s :type weights_path: str :param img_path: Path to file containing all paths to validation images. :type img_path: str - :param class_names: Dict containing detection and segmentation class names - :type class_names: Dict + :param class_config: Object containing all class name related settings + :type class_config: ClassConfig :param batch_size: Size of each image batch, defaults to 8 :type batch_size: int, optional :param img_size: Size of each image dimension for yolo, defaults to 416 @@ -48,62 +50,82 @@ def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_s :type nms_thres: float, optional :param verbose: If True, prints stats of model, defaults to True :type verbose: bool, optional - :param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist.
- :type robot_class_ids: List[int], optional :return: Returns precision, recall, AP, f1, ap_class """ dataloader = _create_validation_data_loader( img_path, batch_size, img_size, n_cpu) model = load_model(model_path, weights_path) - metrics_output, seg_class_ious = _evaluate( + metrics_output, seg_class_ious, secondary_metric = _evaluate( model, dataloader, - class_names, + class_config, img_size, iou_thres, conf_thres, nms_thres, - verbose, - robot_class_ids=robot_class_ids) - return metrics_output, seg_class_ious + verbose) + return metrics_output, seg_class_ious, secondary_metric -def print_eval_stats(metrics_output, seg_class_ious, class_names, verbose): +def print_eval_stats(metrics_output: Optional[Tuple[np.ndarray]], + seg_class_ious: List[np.float64], + secondary_metric: Optional[Metric], + class_config: ClassConfig, + verbose: bool + ): # Print detection statistics + print("#### Detection ####") if metrics_output is not None: precision, recall, AP, f1, ap_class = metrics_output if verbose: # Prints class AP and mean AP ap_table = [["Index", "Class", "AP"]] + class_names = class_config.get_grouped_det_class_names() for i, c in enumerate(ap_class): - ap_table += [[c, class_names['detection'][c], "%.5f" % AP[i]]] + ap_table += [[c, class_names[c], "%.5f" % AP[i]]] print(AsciiTable(ap_table).table) print(f"---- mAP {AP.mean():.5f} ----") else: print("---- mAP not measured (no detections found by model) ----") + if secondary_metric is not None: + print("#### Detection - Secondary ####") + mbACC = secondary_metric.mbACC() + + if verbose: + classes = class_config.get_group_class_names() + mbACC_per_class = [secondary_metric.bACC(i) for i in range(len(classes))] + + sec_table = [["Index", "Class", "bACC"]] + for i, c in enumerate(classes): + sec_table += [[i, c, "%.5f" % mbACC_per_class[i]]] + print(AsciiTable(sec_table).table) + + print(f"---- mbACC {mbACC:.5f} ----") + + print("#### Segmentation ####") # Print segmentation statistics if verbose: # Print IoU per segmentation class seg_table = [["Index", "Class", "IoU"]] + class_names = class_config.get_seg_class_names() for i, iou in enumerate(seg_class_ious): - seg_table += [[i, class_names['segmentation'][i], "%.5f" % iou]] + seg_table += [[i, class_names[i], "%.5f" % iou]] print(AsciiTable(seg_table).table) # Print mean IoU mean_seg_class_ious = np.array(seg_class_ious).mean() print(f"----Average IoU {mean_seg_class_ious:.5f} ----") -def _evaluate(model, dataloader, class_names, img_size, iou_thres, conf_thres, nms_thres, verbose, - robot_class_ids: Optional[List[int]] = None): +def _evaluate(model, dataloader, class_config, img_size, iou_thres, conf_thres, nms_thres, verbose): """Evaluate model on validation dataset. :param model: Model to evaluate :type model: models.Darknet :param dataloader: Dataloader provides the batches of images with targets :type dataloader: DataLoader - :param class_names: Dict containing detection and segmentation class names - :type class_names: Dict + :param class_config: Object storing all class related settings + :type class_config: ClassConfig :param img_size: Size of each image dimension for yolo :type img_size: int :param iou_thres: IOU threshold required to qualify as detected @@ -114,8 +136,6 @@ def _evaluate(model, dataloader, class_names, img_size, iou_thres, conf_thres, n :type nms_thres: float :param verbose: If True, prints stats of model :type verbose: bool - :param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist.
- :type robot_class_ids: List[int], optional :return: Returns precision, recall, AP, f1, ap_class """ model.eval() # Set model to evaluation mode @@ -127,9 +147,21 @@ def _evaluate(model, dataloader, class_names, img_size, iou_thres, conf_thres, n seg_ious = [] import time times = [] + + if class_config.classes_should_be_grouped(): + secondary_metric = Metric(len(class_config.get_group_ids())) + else: + secondary_metric = None + for _, imgs, bb_targets, mask_targets in tqdm.tqdm(dataloader, desc="Validating"): # Extract labels labels += bb_targets[:, 1].tolist() + + # If a subset of the detection classes should be grouped into one class for non-maximum suppression and the + # subsequent AP-computation, we need to group those class labels here. + if class_config.classes_should_be_grouped(): + labels = class_config.group(labels) + # Rescale target bb_targets[:, 2:] = xywh2xyxy(bb_targets[:, 2:]) bb_targets[:, 2:] *= img_size @@ -144,10 +176,20 @@ def _evaluate(model, dataloader, class_names, img_size, iou_thres, conf_thres, n yolo_outputs, conf_thres=conf_thres, iou_thres=nms_thres, - robot_class_ids=robot_class_ids + group_config=class_config.get_group_config() ) - sample_metrics += get_batch_statistics(yolo_outputs, bb_targets, iou_threshold=iou_thres) + sample_stat, secondary_stat = get_batch_statistics( + yolo_outputs, + bb_targets, + iou_threshold=iou_thres, + group_config=class_config.get_group_config() + ) + + sample_metrics += sample_stat + + if class_config.classes_should_be_grouped(): + secondary_metric += secondary_stat seg_ious.append(seg_iou(to_cpu(segmentation_outputs), mask_targets, model.num_seg_classes)) @@ -160,6 +202,7 @@ def _evaluate(model, dataloader, class_names, img_size, iou_thres, conf_thres, n # Concatenate sample statistics true_positives, pred_scores, pred_labels = [ np.concatenate(x, 0) for x in list(zip(*sample_metrics))] + yolo_metrics_output = ap_per_class( true_positives, pred_scores, pred_labels, labels) @@ -175,9 +218,9 @@ def seg_iou_mean_without_nan(seg_iou: List[float]) -> np.ndarray: seg_class_ious = [seg_iou_mean_without_nan(class_ious) for class_ious in list(zip(*seg_ious))] - print_eval_stats(yolo_metrics_output, seg_class_ious, class_names, verbose) + print_eval_stats(yolo_metrics_output, seg_class_ious, secondary_metric, class_config, verbose) - return yolo_metrics_output, seg_class_ious + return yolo_metrics_output, seg_class_ious, secondary_metric def _create_validation_data_loader(img_path, batch_size, img_size, n_cpu): @@ -221,8 +264,7 @@ def run(): parser.add_argument("--iou_thres", type=float, default=0.5, help="IOU threshold required to qualify as detected") parser.add_argument("--conf_thres", type=float, default=0.01, help="Object confidence threshold") parser.add_argument("--nms_thres", type=float, default=0.4, help="IOU threshold for non-maximum suppression") - parser.add_argument("--multiple_robot_classes", action="store_true", - help="If multiple robot classes exist and nms shall be performed across all robot classes") + parser.add_argument("--class_config", type=str, default="class_config/default.yaml", help="Class configuration for evaluation") args = parser.parse_args() print(f"Command line arguments: {args}") @@ -231,28 +273,22 @@ def run(): data_config = parse_data_config(args.data) # Path to file containing all images for validation valid_path = data_config["valid"] - class_names = load_classes(data_config["names"]) # Detection and segmentation class names - robot_class_ids = None - if args.multiple_robot_classes: - 
robot_class_ids = [] - for idx, c in enumerate(class_names["detection"]): - if "robot" in c: - robot_class_ids.append(idx) + class_names = ClassNames.load_from(data_config["names"]) # Detection and segmentation class names + class_config = ClassConfig.load_from(args.class_config, class_names) evaluate_model_file( args.model, args.weights, valid_path, - class_names, + class_config, batch_size=args.batch_size, img_size=args.img_size, n_cpu=args.n_cpu, iou_thres=args.iou_thres, conf_thres=args.conf_thres, nms_thres=args.nms_thres, - verbose=True, - robot_class_ids=robot_class_ids + verbose=args.verbose, ) diff --git a/yoeo/train.py b/yoeo/train.py index ef328cd..6f582fe 100755 --- a/yoeo/train.py +++ b/yoeo/train.py @@ -13,12 +13,12 @@ import torch.optim as optim from torch.autograd import Variable -from typing import List, Optional - from yoeo.models import load_model from yoeo.utils.logger import Logger -from yoeo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set +from yoeo.utils.utils import to_cpu, print_environment_info, provide_determinism, worker_seed_set from yoeo.utils.datasets import ListDataset +from yoeo.utils.dataclasses import ClassNames +from yoeo.utils.class_config import ClassConfig from yoeo.utils.augmentations import AUGMENTATION_TRANSFORMS from yoeo.utils.transforms import DEFAULT_TRANSFORMS from yoeo.utils.parse_config import parse_data_config @@ -80,8 +80,7 @@ def run(): parser.add_argument("--nms_thres", type=float, default=0.5, help="Evaluation: IOU threshold for non-maximum suppression") parser.add_argument("--logdir", type=str, default="logs", help="Directory for training log files (e.g. for TensorBoard)") parser.add_argument("--seed", type=int, default=-1, help="Makes results reproducable. 
Set -1 to disable.") - parser.add_argument("--multiple_robot_classes", action="store_true", - help="If multiple robot classes exist and nms shall be performed across all robot classes") + parser.add_argument("--class_config", type=str, default="class_config/default.yaml", help="Class configuration for evaluation") args = parser.parse_args() print(f"Command line arguments: {args}") @@ -98,14 +97,9 @@ def run(): data_config = parse_data_config(args.data) train_path = data_config["train"] valid_path = data_config["valid"] - class_names = load_classes(data_config["names"]) - robot_class_ids = None - if args.multiple_robot_classes: - robot_class_ids = [] - for idx, c in enumerate(class_names["detection"]): - if "robot" in c: - robot_class_ids.append(idx) + class_names = ClassNames.load_from(data_config["names"]) + class_config = ClassConfig.load_from(args.class_config, class_names) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -256,13 +250,12 @@ def run(): metrics_output = _evaluate( model, validation_dataloader, - class_names, + class_config=class_config, img_size=model.hyperparams['height'], iou_thres=args.iou_thres, conf_thres=args.conf_thres, nms_thres=args.nms_thres, verbose=args.verbose, - robot_class_ids=robot_class_ids ) if metrics_output is not None: @@ -274,6 +267,10 @@ def run(): ("validation/mAP", AP.mean()), ("validation/f1", f1.mean()), ("validation/seg_iou", np.array(seg_class_ious).mean())] + + if metrics_output[2] is not None: + evaluation_metrics.append(("validation/secondary_mbACC", metrics_output[2].mbACC())) + logger.list_of_scalars_summary(evaluation_metrics, epoch) diff --git a/yoeo/utils/class_config.py b/yoeo/utils/class_config.py new file mode 100644 index 0000000..8dc2cc4 --- /dev/null +++ b/yoeo/utils/class_config.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import yaml + +from typing import Dict, List, Any, Optional + +from yoeo.utils.dataclasses import ClassNames, GroupConfig + + +class ClassConfig: + def __init__(self, content: Dict[Any, Any], class_names: ClassNames): + self._det_class_names: List[str] = class_names.detection + self._seg_class_names: List[str] = class_names.segmentation + + self._class_names_to_group: List[str] = content["group_classes"] + self._group_surrogate_name: Optional[str] = content["surrogate_class"] + + self._ids_to_group: Optional[List[int]] = self._compute_group_ids() + self._grouped_det_class_names: List[str] = self._group_class_names() + + def _compute_group_ids(self) -> Optional[List[int]]: + """ + Given the list of detection class names and the list of class names that should be grouped into one class, + compute the ids of the latter classes, i.e. their position in the list of detection class names. + + :return: The ids of all class names that should be grouped into one class if there are any. None otherwise. + :rtype: Optional[List[int]] + """ + group_ids = None + + if self._class_names_to_group: + group_ids = [] + + for idx, class_name in enumerate(self._det_class_names): + if class_name in self._class_names_to_group: + group_ids.append(idx) + + return group_ids + + def _group_class_names(self) -> List[str]: + """ + Given the list of detection class names and the list of class names that should be grouped into one class, + compute a new list of class names in which all of the latter class names are removed and the surrogate class + name is inserted at the position of the first class of the classes that should be grouped. 
+ + :return: A list of class names in which all class names that should be grouped are removed and the surrogate + class name is inserted as a surrogate for those classes + :rtype: List[str] + """ + + # Copy the list of detection class names + grouped_class_names = list(self._det_class_names) + + if self._ids_to_group: + # Insert the surrogate class name before the first to be grouped class name + grouped_class_names.insert(self.get_surrogate_id(), self._group_surrogate_name) + + # Remove all to be grouped class names + for name in self._class_names_to_group: + grouped_class_names.remove(name) + + return grouped_class_names + + def get_group_config(self) -> Optional[GroupConfig]: + """ + Get the current 'GroupConfig'. + + :return: The current 'GroupConfig' if neither 'self.get_group_ids()' nor 'self.get_surrogate_id()' is + 'None'. Return 'None' otherwise. + :rtype: Optional[GroupConfig] + """ + + group_ids = self.get_group_ids() + surrogate_id = self.get_surrogate_id() + + if group_ids is None or surrogate_id is None: + return None + else: + return GroupConfig(group_ids, surrogate_id) + + def get_group_class_names(self) -> List[str]: + """ + Get the class names of the classes that should be grouped together during evaluation + + :return: a list of class names that should be grouped together during evaluation + :rtype: List[str] + """ + return self._class_names_to_group + + def get_surrogate_id(self) -> Optional[int]: + """ + Get the id of the surrogate class in the list of grouped class names. If there are no classes to be grouped, + None is returned. + + :return: The id of the surrogate class in the list of grouped class names if there are classes that should be + grouped. None otherwise. + :rtype: Optional[int] + """ + return None if not self._ids_to_group else self._ids_to_group[0] + + def get_grouped_det_class_names(self) -> List[str]: + """ + Get the grouped list of detection class names. + + :return: The grouped list of detection class names. + :rtype: List[str] + """ + + return self._grouped_det_class_names + + def get_ungrouped_det_class_names(self) -> List[str]: + """ + Get the ungrouped list of detection class names. + + :return: The ungrouped list of detection class names. + :rtype: List[str] + """ + + return self._det_class_names + + def get_seg_class_names(self) -> List[str]: + """ + Get the list of segmentation class names. + + :return: The list of segmentation class names. + :rtype: List[str] + """ + + return self._seg_class_names + + def get_group_ids(self) -> Optional[List[int]]: + """ + Get the (ungrouped) ids of the class names that should be grouped into one class. + + :return: A list of ungrouped ids for the class names that should be grouped into one class if there are any. + None otherwise + :rtype: Optional[List[int]] + """ + return self._ids_to_group + + def get_surrogate_name(self) -> Optional[str]: + """ + Get the class name of the surrogate class if there are classes that should be grouped into one class. Return + None otherwise. + + :return: The name of the surrogate class if there are classes that should be grouped into one class. None + otherwise. + :rtype: Optional[List[str]] + """ + + return self._group_surrogate_name + + def classes_should_be_grouped(self) -> bool: + """ + Return true if there are classes that should be grouped into one class. Return false otherwise. + + :return: true if there are classes that should be grouped into on class. False otherwise. 
+ :rtype: bool + """ + return self._ids_to_group is not None + + def group(self, labels: List[int]) -> List[int]: + """ + Group a list of class ids. Given a set of classes that should be grouped X, replace all class ids in X by + the surrogate id. + + :param labels: list of class ids to group. + :type labels: List[int] + + :return: grouped list of class ids in which every id from X is replaced by the surrogate id + :rtype: List[int] + """ + surrogate_id = self.get_surrogate_id() + + return [label if label not in self._ids_to_group else surrogate_id for label in labels] + + @classmethod + def load_from(cls, path: str, class_names: ClassNames) -> ClassConfig: + content = cls._read_yaml_file(path) + + return ClassConfig(content, class_names) + + @staticmethod + def _read_yaml_file(path: str) -> Dict[Any, Any]: + with open(path, "r") as f: + return yaml.safe_load(f) diff --git a/yoeo/utils/dataclasses.py b/yoeo/utils/dataclasses.py new file mode 100644 index 0000000..c2973cd --- /dev/null +++ b/yoeo/utils/dataclasses.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import yaml + +from dataclasses import dataclass +from typing import Any, Dict, List + + +@dataclass +class ClassNames: + detection: List[str] + segmentation: List[str] + + @classmethod + def load_from(cls, path: str) -> ClassNames: + file_content = cls._read_yaml_file(path) + class_names = cls._parse_yaml_file(file_content) + + return class_names + + @staticmethod + def _parse_yaml_file(content: Dict[Any, Any]) -> ClassNames: + return ClassNames(**content) + + @staticmethod + def _read_yaml_file(path: str) -> Dict[Any, Any]: + with open(path, "r") as f: + return yaml.safe_load(f) + + +@dataclass +class GroupConfig: + group_ids: List[int] + surrogate_id: int diff --git a/yoeo/utils/metric.py b/yoeo/utils/metric.py new file mode 100644 index 0000000..42add89 --- /dev/null +++ b/yoeo/utils/metric.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import numpy as np + + +class Metric: + """ + Metric object providing useful metrics based on a confusion matrix + """ + + def __init__(self, n_classes): + self._n_classes = n_classes + self._conf_matrix = np.zeros(shape=(n_classes, n_classes)) + + def __add__(self, other: Metric): + assert type(other) == Metric, "cannot add other than Metric" + assert other._n_classes == self._n_classes, "Dimensions mismatch" + + m = Metric(self._n_classes) + m._conf_matrix = self._conf_matrix + other._conf_matrix + + return m + + def _tp(self, class_id: int) -> int: + return self._conf_matrix[class_id, class_id] + + def _fp(self, class_id: int) -> int: + return np.sum(self._conf_matrix[class_id, :]) - self._conf_matrix[class_id, class_id] + + def _fn(self, class_id: int) -> int: + return np.sum(self._conf_matrix[:, class_id]) - self._conf_matrix[class_id, class_id] + + def _tn(self, class_id: int) -> int: + return np.sum(self._conf_matrix) - self._tp(class_id) - self._fp(class_id) - self._fn(class_id) + + + def update(self, pred: int, target: int) -> None: + self._conf_matrix[pred, target] += 1 + + def merge(self, metric: Metric) -> None: + self._conf_matrix += metric._conf_matrix + + def reset(self) -> None: + self._conf_matrix = np.zeros(shape=(self._n_classes, self._n_classes)) + + def get_conf_matrix(self) -> np.ndarray: + return self._conf_matrix + + + def ACC(self, class_id: int) -> float: + denom = np.sum(self._conf_matrix) + return (self._tp(class_id) + self._tn(class_id)) / denom if denom != 0 else float("nan") + + def mACC(self) -> float: + return self._mean(self.ACC) + + def bACC(self, class_id: int) -> float: + return (self.TPR(class_id) + self.TNR(class_id)) / 2
+ def mbACC(self) -> float: + return self._mean(self.bACC) + + def _mean(self, fun) -> float: + return np.mean([fun(i) for i in range(self._n_classes)]) + + def PREC(self, class_id: int) -> float: + denom = (self._tp(class_id) + self._fp(class_id)) + return self._tp(class_id) / denom if denom != 0 else float("nan") + + def REC(self, class_id: int) -> float: + return self.TPR(class_id) + + def F1(self, class_id: int) -> float: + denom = (2 * self._tp(class_id) + self._fp(class_id) + self._fn(class_id)) + return 2 * self._tp(class_id) / denom if denom != 0 else float("nan") + + def TNR(self, class_id: int) -> float: + denom = self._fp(class_id) + self._tn(class_id) + return self._tn(class_id) / denom if denom != 0 else float("nan") + + def TPR(self, class_id: int) -> float: + denom = (self._tp(class_id) + self._fn(class_id)) + return self._tp(class_id) / denom if denom != 0 else float("nan") + \ No newline at end of file diff --git a/yoeo/utils/utils.py b/yoeo/utils/utils.py index a11bbe8..de62950 100644 --- a/yoeo/utils/utils.py +++ b/yoeo/utils/utils.py @@ -1,7 +1,5 @@ from __future__ import division, annotations -from typing import Tuple - import time import platform import tqdm @@ -11,8 +9,10 @@ import numpy as np import subprocess import random -from typing import List, Optional -import yaml +from typing import List, Optional, Tuple + +from yoeo.utils.dataclasses import GroupConfig +from yoeo.utils.metric import Metric def provide_determinism(seed=42): @@ -45,16 +45,6 @@ def to_cpu(tensor): return tensor.detach().cpu() -def load_classes(path: str) -> dict: - with open(path, 'r', encoding="utf-8") as fp: - names = yaml.load(fp, Loader=yaml.SafeLoader) - - assert "detection" in names.keys(), f"Missing key 'detection' in {path}" - assert "segmentation" in names.keys(), f"Missing key 'segmentation' in {path}" - - return names - - def weights_init_normal(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: @@ -298,11 +288,33 @@ def compute_ap(recall, precision): return ap -def get_batch_statistics(outputs, targets, iou_threshold): +def get_batch_statistics(outputs, + targets, + iou_threshold, + group_config: Optional[GroupConfig] = None + ) -> Tuple[List, Optional[Metric]]: + """ + Calculate the batch statistics. If 'group_config' is not 'None', the contained classes will be grouped into one + class ('GroupConfig.surrogate_id') for batch statistics evaluation and evaluated separately on a secondary class + label. The statistics for the latter are returned as a 'Metric' object. If 'group_config' is None, no 'Metric' + object will be returned and the tuple will simply contain 'None' at the respective position.
+ + :return: The batch statistics, as well as an optional Metric object for the secondary class argument if + 'group_config' is not None + :rtype: Tuple[List, Optional[Metric]] + """ """ Compute true positives, predicted scores and predicted labels per sample """ batch_metrics = [] - for sample_i in range(len(outputs)): + grouping_active: bool = group_config is not None + + if grouping_active: + secondary_metric = Metric(len(group_config.group_ids)) + group_ids = torch.tensor(group_config.group_ids) + else: + secondary_metric = None + + for sample_i in range(len(outputs)): if outputs[sample_i] is None: continue @@ -311,16 +323,24 @@ def get_batch_statistics(outputs, targets, iou_threshold): output = outputs[sample_i] pred_boxes = output[:, :4] pred_scores = output[:, 4] pred_labels = output[:, -1] + if grouping_active: + sec_pred_labels = compute_secondary_labels(pred_labels, group_ids) + pred_labels = group_primary_labels(pred_labels, group_ids, group_config.surrogate_id) + true_positives = np.zeros(pred_boxes.shape[0]) annotations = targets[targets[:, 0] == sample_i][:, 1:] target_labels = annotations[:, 0] if len(annotations) else [] + + if grouping_active and type(target_labels) is not list: + sec_target_labels = compute_secondary_labels(target_labels, group_ids) + target_labels = group_primary_labels(target_labels, group_ids, group_config.surrogate_id) + if len(annotations): detected_boxes = [] target_boxes = annotations[:, 1:] for pred_i, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)): - # If targets are found break if len(detected_boxes) == len(annotations): break @@ -343,8 +363,36 @@ def get_batch_statistics(outputs, targets, iou_threshold): if iou >= iou_threshold and box_index not in detected_boxes: true_positives[pred_i] = 1 detected_boxes += [box_index] + + if grouping_active: + sec_pred_label = sec_pred_labels[pred_i] + + if pred_label in group_ids: + secondary_metric.update(sec_pred_label.int(), sec_target_labels[box_index].int()) + batch_metrics.append([true_positives, pred_scores, pred_labels]) - return batch_metrics + + return batch_metrics, secondary_metric + +def compute_secondary_labels(labels: torch.tensor, group_ids: torch.tensor) -> torch.tensor: + secondary_labels = labels.clone() + + # We replace the actual class labels with values from {0, ...} for classes that should be grouped into a + # single class. All other classes get the label -1. + for idx, squeeze_id in enumerate(group_ids): + # Replace label with value in {0, ...} + secondary_labels[labels == squeeze_id] = idx + + # Replace all other labels with -1 + secondary_labels[torch.logical_not(torch.isin(labels, group_ids))] = -1 + + return secondary_labels + +def group_primary_labels(labels: torch.tensor, group_ids: torch.tensor, surrogate_id: int) -> torch.tensor: + # Replace all primary labels that are contained in group_ids with the surrogate_id + labels[torch.isin(labels, group_ids)] = surrogate_id + + return labels def bbox_wh_iou(wh1, wh2): @@ -419,8 +467,11 @@ def box_area(box): return inter / (area1[:, None] + area2 - inter) def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, - robot_class_ids: Optional[List[int]] = None): - """Performs Non-Maximum Suppression (NMS) on inference results + group_config: Optional[GroupConfig] = None): + """ + Performs Non-Maximum Suppression (NMS) on inference results. If 'group_config' is not 'None', the contained + classes will be treated as one class ('GroupConfig.surrogate_id') during non-maximum suppression.
+ Returns: detections with shape: nx6 (x1, y1, x2, y2, conf, cls) """ @@ -437,8 +488,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non t = time.time() output = [torch.zeros((0, 6), device="cpu")] * prediction.shape[0] - if robot_class_ids: - robot_class_ids = torch.tensor(robot_class_ids, device=prediction.device, dtype=prediction.dtype) + if group_config: + group_ids = torch.tensor(group_config.group_ids, device=prediction.device, dtype=prediction.dtype) for xi, x in enumerate(prediction): # image index, image inference # Apply constraints @@ -476,13 +527,13 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non x = x[x[:, 4].argsort(descending=True)[:max_nms]] # Batched NMS - if robot_class_ids is None: + if group_config is None: c = x[:, 5:6] * max_wh # classes else: - # If multiple robot classes are present, all robot classes are treated as one class in order to perform - # nms across all classes and not per class. For this, all robot classes get the same offset. + # If for example multiple robot classes are present, all robot classes are treated as one class in order + # to perform nms across all classes and not per class. For this, all robot classes get the same offset. c = torch.clone(x[:, 5:6]) - c[torch.isin(c, robot_class_ids)] = robot_class_ids[0] + c[torch.isin(c, group_ids)] = group_config.surrogate_id c *= max_wh # boxes (offset by class), scores
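Usage sketch (reviewer note, not part of the diff): the snippet below wires the new pieces together for single-image inference. The model, weight and YAML paths as well as the dummy image are placeholders; the class-names file is assumed to be a YAML with 'detection' and 'segmentation' lists, as implied by the removed 'load_classes' helper and the new 'ClassNames' dataclass.

import numpy as np

from yoeo.models import load_model
from yoeo.utils.dataclasses import ClassNames
from yoeo.utils.class_config import ClassConfig
from yoeo.detect import detect_image

# Placeholder paths -- substitute your own model definition, weights and YAML files.
model = load_model("config/yoeo.cfg", "weights/yoeo.pth")
class_names = ClassNames.load_from("data/class_names.yaml")  # keys: detection, segmentation
class_config = ClassConfig.load_from("class_config/colored_robots.yaml", class_names)

# Placeholder image; in practice load a real image as an np.ndarray (H x W x C).
image = np.zeros((480, 640, 3), dtype=np.uint8)

# With colored_robots.yaml, robot_red/robot_blue/robot_unknown share one NMS group ("robot").
detections, segmentation = detect_image(
    model,
    image,
    img_size=416,
    conf_thres=0.5,
    nms_thres=0.5,
    group_config=class_config.get_group_config(),
)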
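To make the grouping semantics concrete, a small worked example follows (the five detection classes are made up for illustration; the calls mirror what ClassConfig does with the colored_robots.yaml settings):

from yoeo.utils.class_config import ClassConfig
from yoeo.utils.dataclasses import ClassNames

# Hypothetical class-name set, for illustration only.
names = ClassNames(
    detection=["ball", "robot_red", "robot_blue", "robot_unknown", "goalpost"],
    segmentation=["background", "lines", "field"],
)
content = {"group_classes": ["robot_red", "robot_blue", "robot_unknown"], "surrogate_class": "robot"}
config = ClassConfig(content, names)

assert config.get_group_ids() == [1, 2, 3]        # positions of the robot classes in the detection list
assert config.get_surrogate_id() == 1             # id of the first grouped class acts as surrogate
assert config.get_grouped_det_class_names() == ["ball", "robot", "goalpost"]
assert config.group([0, 2, 3, 4]) == [0, 1, 1, 4]  # robot ids collapse to the surrogate id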
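The secondary detection metric reported during evaluation is a balanced accuracy computed from a confusion matrix over the grouped classes (Metric.bACC = (TPR + TNR) / 2, averaged over classes by mbACC). A minimal numeric sketch with made-up counts for a two-class secondary problem:

from yoeo.utils.metric import Metric

m = Metric(n_classes=2)        # e.g. robot_red vs. robot_blue
m.update(pred=0, target=0)     # confusion matrix is indexed as [pred, target]
m.update(pred=0, target=0)
m.update(pred=1, target=0)     # one class-0 robot predicted as class 1
m.update(pred=1, target=1)

print(m.bACC(0))   # (TPR 2/3 + TNR 1) / 2 = 0.8333...
print(m.bACC(1))   # (TPR 1 + TNR 2/3) / 2 = 0.8333...
print(m.mbACC())   # mean over both classes = 0.8333...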