diff --git a/configs/mm_grounding_dino/coco/grounding_dino_swin-t_finetune_16xb4_1x_coco_48_17.py b/configs/mm_grounding_dino/coco/grounding_dino_swin-t_finetune_16xb4_1x_coco_48_17.py
new file mode 100644
index 00000000000..43503fb8bea
--- /dev/null
+++ b/configs/mm_grounding_dino/coco/grounding_dino_swin-t_finetune_16xb4_1x_coco_48_17.py
@@ -0,0 +1,158 @@
+_base_ = '../grounding_dino_swin-t_pretrain_obj365.py'
+
+data_root = 'data/coco/'
+base_classes = ('person', 'bicycle', 'car', 'motorcycle', 'train', 'truck',
+                'boat', 'bench', 'bird', 'horse', 'sheep', 'bear', 'zebra',
+                'giraffe', 'backpack', 'handbag', 'suitcase', 'frisbee',
+                'skis', 'kite', 'surfboard', 'bottle', 'fork', 'spoon', 'bowl',
+                'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+                'pizza', 'donut', 'chair', 'bed', 'toilet', 'tv', 'laptop',
+                'mouse', 'remote', 'microwave', 'oven', 'toaster',
+                'refrigerator', 'book', 'clock', 'vase', 'toothbrush')
+novel_classes = ('airplane', 'bus', 'cat', 'dog', 'cow', 'elephant',
+                 'umbrella', 'tie', 'snowboard', 'skateboard', 'cup', 'knife',
+                 'cake', 'couch', 'keyboard', 'sink', 'scissors')
+all_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+               'train', 'truck', 'boat', 'bench', 'bird', 'cat', 'dog',
+               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+               'skis', 'snowboard', 'kite', 'skateboard', 'surfboard',
+               'bottle', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
+               'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'pizza',
+               'donut', 'cake', 'chair', 'couch', 'bed', 'toilet', 'tv',
+               'laptop', 'mouse', 'remote', 'keyboard', 'microwave', 'oven',
+               'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+               'scissors', 'toothbrush')
+
+train_metainfo = dict(classes=base_classes)
+test_metainfo = dict(
+    classes=all_classes,
+    base_classes=base_classes,
+    novel_classes=novel_classes)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(
+        type='RandomChoice',
+        transforms=[
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    # The aspect ratio of every image in the train dataset
+                    # is < 7, following the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type='RandomCrop',
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'flip', 'flip_direction', 'text',
+                   'custom_entities'))
+]
+
+test_pipeline = [
+    dict(
+        type='LoadImageFromFile', backend_args=None,
+        imdecode_backend='pillow'),
+    dict(
+        type='FixScaleResize',
+        scale=(800, 1333),
+        keep_ratio=True,
+        backend='pillow'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'text', 'custom_entities',
+                   'tokens_positive'))
+]
+
+train_dataloader = dict(
+    dataset=dict(
+        _delete_=True,
+        type='CocoDataset',
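+        # train_metainfo restricts training to the 48 base classes; the 17
+        # novel classes are held out and only seen by the evaluator.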
+        metainfo=train_metainfo,
+        data_root=data_root,
+        ann_file='zero-shot/instances_train2017_seen_2.json',
+        data_prefix=dict(img='train2017/'),
+        return_classes=True,
+        filter_cfg=dict(filter_empty_gt=False, min_size=32),
+        pipeline=train_pipeline))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type='CocoDataset',
+        metainfo=test_metainfo,
+        data_root=data_root,
+        ann_file='zero-shot/instances_val2017_all_2.json',
+        data_prefix=dict(img='val2017/'),
+        test_mode=True,
+        pipeline=test_pipeline,
+        return_classes=True,
+    ))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    type='OVCocoMetric',
+    ann_file=data_root + 'zero-shot/instances_val2017_all_2.json',
+    metric='bbox',
+    format_only=False)
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=0.00005, weight_decay=0.0001),
+    clip_grad=dict(max_norm=0.1, norm_type=2),
+    paramwise_cfg=dict(
+        custom_keys={
+            'absolute_pos_embed': dict(decay_mult=0.),
+            'backbone': dict(lr_mult=0.1),
+            # 'language_model': dict(lr_mult=0),
+        }))
+
+# learning policy
+max_epochs = 12
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[8, 11],
+        gamma=0.1)
+]
+train_cfg = dict(max_epochs=max_epochs, val_interval=1)
+
+default_hooks = dict(
+    checkpoint=dict(
+        max_keep_ckpts=1, save_best='coco/novel_ap50', rule='greater'))
+
+load_from = 'epoch_30.pth'
diff --git a/configs/mm_grounding_dino/lvis/grounding_dino_swin-t_finetune_16xb4_1x_lvis_866_337.py b/configs/mm_grounding_dino/lvis/grounding_dino_swin-t_finetune_16xb4_1x_lvis_866_337.py
new file mode 100644
index 00000000000..07d129c39b8
--- /dev/null
+++ b/configs/mm_grounding_dino/lvis/grounding_dino_swin-t_finetune_16xb4_1x_lvis_866_337.py
@@ -0,0 +1,120 @@
+_base_ = '../grounding_dino_swin-t_pretrain_obj365.py'
+
+data_root = 'data/lvis/'
+
+model = dict(test_cfg=dict(
+    max_per_img=300,
+    chunked_size=40,
+))
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(
+        type='RandomChoice',
+        transforms=[
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    # The aspect ratio of every image in the train dataset
+                    # is < 7, following the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type='RandomCrop',
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
+    dict(
+        type='RandomSamplingNegPos',
+        tokenizer_name=_base_.lang_model_name,
+        num_sample_negative=85,
+        # change this to your local path of the label map file
+        label_map_file='data/lvis/annotations/lvis_v1_label_map_norare.json',
+        max_tokens=256),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'flip', 'flip_direction', 'text',
+                   'custom_entities', 'tokens_positive', 'dataset_mode'))
+]
+
+train_dataloader = dict(
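+    # ClassBalancedDataset applies repeat-factor sampling from the LVIS
+    # paper: images whose rarest category's frequency falls below
+    # oversample_thr are repeated more often during training.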
+    dataset=dict(
+        _delete_=True,
+        type='ClassBalancedDataset',
+        oversample_thr=1e-3,
+        dataset=dict(
+            type='ODVGDataset',
+            data_root=data_root,
+            need_text=False,
+            label_map_file='annotations/lvis_v1_label_map_norare.json',
+            ann_file='annotations/lvis_v1_train_od_norare.json',
+            data_prefix=dict(img=''),
+            filter_cfg=dict(filter_empty_gt=False, min_size=32),
+            return_classes=True,
+            pipeline=train_pipeline)))
+
+val_dataloader = dict(
+    dataset=dict(
+        data_root=data_root,
+        type='LVISV1Dataset',
+        ann_file='annotations/lvis_v1_minival_inserted_image_name.json',
+        data_prefix=dict(img='')))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    _delete_=True,
+    type='LVISFixedAPMetric',
+    ann_file=data_root +
+    'annotations/lvis_v1_minival_inserted_image_name.json')
+test_evaluator = val_evaluator
+
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=0.00005, weight_decay=0.0001),
+    clip_grad=dict(max_norm=0.1, norm_type=2),
+    paramwise_cfg=dict(
+        custom_keys={
+            'absolute_pos_embed': dict(decay_mult=0.),
+            'backbone': dict(lr_mult=0.1),
+            # 'language_model': dict(lr_mult=0),
+        }))
+
+# learning policy
+max_epochs = 12
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[8, 11],
+        gamma=0.1)
+]
+train_cfg = dict(max_epochs=max_epochs, val_interval=3)
+
+default_hooks = dict(
+    checkpoint=dict(
+        max_keep_ckpts=3, save_best='lvis_fixed_ap/AP', rule='greater'))
+
+load_from = 'epoch_30.pth'
diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py
index 4b61894dbbb..8ad040cf6ff 100644
--- a/mmdet/evaluation/metrics/__init__.py
+++ b/mmdet/evaluation/metrics/__init__.py
@@ -16,6 +16,7 @@
 from .lvis_metric import LVISMetric
 from .mot_challenge_metric import MOTChallengeMetric
 from .openimages_metric import OpenImagesMetric
+from .ov_coco_metric import OVCocoMetric
 from .refexp_metric import RefExpMetric
 from .refseg_metric import RefSegMetric
 from .reid_metric import ReIDMetrics
@@ -29,5 +30,6 @@
     'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric',
     'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric',
     'COCOCaptionMetric', 'SemSegMetric', 'RefSegMetric', 'RefExpMetric',
-    'gRefCOCOMetric', 'DODCocoMetric', 'DumpODVGResults', 'Flickr30kMetric'
+    'gRefCOCOMetric', 'DODCocoMetric', 'DumpODVGResults', 'Flickr30kMetric',
+    'OVCocoMetric'
 ]
diff --git a/mmdet/evaluation/metrics/ov_coco_metric.py b/mmdet/evaluation/metrics/ov_coco_metric.py
new file mode 100644
index 00000000000..08cb9025149
--- /dev/null
+++ b/mmdet/evaluation/metrics/ov_coco_metric.py
@@ -0,0 +1,266 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+from typing import Dict
+
+import numpy as np
+from mmengine.fileio import load
+from mmengine.logging import MMLogger
+from terminaltables import AsciiTable
+
+from mmdet.datasets.api_wrappers import COCO, COCOeval, COCOevalMP
+from mmdet.registry import METRICS
+from .coco_metric import CocoMetric
+
+
+@METRICS.register_module()
+class OVCocoMetric(CocoMetric):
+    """COCO evaluation metric for open-vocabulary detection.
+
+    Besides the standard COCO AP, AP is additionally reported on the base
+    and novel category splits defined in the dataset metainfo.
+    """
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+                the metrics, and the values are corresponding results.
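+                Besides the standard COCO metrics, the keys ``base_ap``,
+                ``novel_ap``, ``base_ap50`` and ``novel_ap50`` report AP
+                over the base and novel category splits.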
+        """
+        logger: MMLogger = MMLogger.get_current_instance()
+
+        # split gt and prediction list
+        gts, preds = zip(*results)
+
+        tmp_dir = None
+        if self.outfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            outfile_prefix = osp.join(tmp_dir.name, 'results')
+        else:
+            outfile_prefix = self.outfile_prefix
+
+        if self._coco_api is None:
+            # use converted gt json file to initialize coco api
+            logger.info('Converting ground truth to coco format...')
+            coco_json_path = self.gt_to_coco_json(
+                gt_dicts=gts, outfile_prefix=outfile_prefix)
+            self._coco_api = COCO(coco_json_path)
+
+        # handle lazy init
+        if self.cat_ids is None:
+            self.cat_ids = self._coco_api.get_cat_ids(
+                cat_names=self.dataset_meta['classes'])
+            self.base_cat_ids = self._coco_api.get_cat_ids(
+                cat_names=self.dataset_meta['base_classes'])
+            self.novel_cat_ids = self._coco_api.get_cat_ids(
+                cat_names=self.dataset_meta['novel_classes'])
+
+        if self.img_ids is None:
+            self.img_ids = self._coco_api.get_img_ids()
+
+        # convert predictions to coco format and dump to json file
+        result_files = self.results2json(preds, outfile_prefix)
+
+        eval_results = OrderedDict()
+        if self.format_only:
+            logger.info('results are saved in '
+                        f'{osp.dirname(outfile_prefix)}')
+            return eval_results
+
+        for metric in self.metrics:
+            logger.info(f'Evaluating {metric}...')
+
+            # TODO: May refactor fast_eval_recall to an independent metric?
+            # fast eval recall
+            if metric == 'proposal_fast':
+                ar = self.fast_eval_recall(
+                    preds, self.proposal_nums, self.iou_thrs, logger=logger)
+                log_msg = []
+                for i, num in enumerate(self.proposal_nums):
+                    eval_results[f'AR@{num}'] = ar[i]
+                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+                log_msg = ''.join(log_msg)
+                logger.info(log_msg)
+                continue
+
+            # evaluate proposal, bbox and segm
+            iou_type = 'bbox' if metric == 'proposal' else metric
+            if metric not in result_files:
+                raise KeyError(f'{metric} is not in results')
+            try:
+                predictions = load(result_files[metric])
+                if iou_type == 'segm':
+                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
+                    # When evaluating mask AP, if the results contain bbox,
+                    # cocoapi will use the box area instead of the mask area
+                    # for calculating the instance area. Though the overall AP
+                    # is not affected, this leads to different
+                    # small/medium/large mask AP results.
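+                    # 'predictions' was just reloaded from the dumped json
+                    # file above, so popping 'bbox' in place cannot affect
+                    # the original detection results.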
+                    for x in predictions:
+                        x.pop('bbox')
+                coco_dt = self._coco_api.loadRes(predictions)
+
+            except IndexError:
+                logger.error(
+                    'The testing results of the whole dataset are empty.')
+                break
+
+            if self.use_mp_eval:
+                coco_eval = COCOevalMP(self._coco_api, coco_dt, iou_type)
+            else:
+                coco_eval = COCOeval(self._coco_api, coco_dt, iou_type)
+
+            coco_eval.params.catIds = self.cat_ids
+            coco_eval.params.imgIds = self.img_ids
+            coco_eval.params.maxDets = list(self.proposal_nums)
+            coco_eval.params.iouThrs = self.iou_thrs
+
+            # mapping of cocoEval.stats
+            coco_metric_names = {
+                'mAP': 0,
+                'mAP_50': 1,
+                'mAP_75': 2,
+                'mAP_s': 3,
+                'mAP_m': 4,
+                'mAP_l': 5,
+                'AR@100': 6,
+                'AR@300': 7,
+                'AR@1000': 8,
+                'AR_s@1000': 9,
+                'AR_m@1000': 10,
+                'AR_l@1000': 11
+            }
+            metric_items = self.metric_items
+            if metric_items is not None:
+                for metric_item in metric_items:
+                    if metric_item not in coco_metric_names:
+                        raise KeyError(
+                            f'metric item "{metric_item}" is not supported')
+
+            if metric == 'proposal':
+                coco_eval.params.useCats = 0
+                coco_eval.evaluate()
+                coco_eval.accumulate()
+                coco_eval.summarize()
+                if metric_items is None:
+                    metric_items = [
+                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+                        'AR_m@1000', 'AR_l@1000'
+                    ]
+
+                for item in metric_items:
+                    val = float(
+                        f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
+                    eval_results[item] = val
+            else:
+                coco_eval.evaluate()
+                coco_eval.accumulate()
+                coco_eval.summarize()
+                if self.classwise:  # Compute per-category AP
+                    # Compute per-category AP
+                    # from https://github.com/facebookresearch/detectron2/
+                    precisions = coco_eval.eval['precision']
+                    # precision: (iou, recall, cls, area range, max dets)
+                    assert len(self.cat_ids) == precisions.shape[2]
+
+                    results_per_category = []
+                    for idx, cat_id in enumerate(self.cat_ids):
+                        t = []
+                        # area range index 0: all area ranges
+                        # max dets index -1: typically 100 per image
+                        nm = self._coco_api.loadCats(cat_id)[0]
+                        precision = precisions[:, :, idx, 0, -1]
+                        precision = precision[precision > -1]
+                        if precision.size:
+                            ap = np.mean(precision)
+                        else:
+                            ap = float('nan')
+                        t.append(f'{nm["name"]}')
+                        t.append(f'{round(ap, 3)}')
+                        eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
+
+                        # indexes of IoU @50 and @75
+                        for iou in [0, 5]:
+                            precision = precisions[iou, :, idx, 0, -1]
+                            precision = precision[precision > -1]
+                            if precision.size:
+                                ap = np.mean(precision)
+                            else:
+                                ap = float('nan')
+                            t.append(f'{round(ap, 3)}')
+
+                        # indexes of area of small, medium and large
+                        for area in [1, 2, 3]:
+                            precision = precisions[:, :, idx, area, -1]
+                            precision = precision[precision > -1]
+                            if precision.size:
+                                ap = np.mean(precision)
+                            else:
+                                ap = float('nan')
+                            t.append(f'{round(ap, 3)}')
+                        results_per_category.append(tuple(t))
+
+                    num_columns = len(results_per_category[0])
+                    results_flatten = list(
+                        itertools.chain(*results_per_category))
+                    headers = [
+                        'category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s',
+                        'mAP_m', 'mAP_l'
+                    ]
+                    results_2d = itertools.zip_longest(*[
+                        results_flatten[i::num_columns]
+                        for i in range(num_columns)
+                    ])
+                    table_data = [headers]
+                    table_data += [result for result in results_2d]
+                    table = AsciiTable(table_data)
+                    logger.info('\n' + table.table)
+
+                # ------------get novel_ap50 and base_ap50---------
+                precisions = coco_eval.eval['precision']
+                assert len(self.cat_ids) == precisions.shape[2]
+                base_inds, novel_inds = [], []
+
+                for idx, catId in enumerate(self.cat_ids):
+                    if catId in self.base_cat_ids:
+                        base_inds.append(idx)
+                    if catId in self.novel_cat_ids:
+                        novel_inds.append(idx)
+
+                base_ap = precisions[:, :, base_inds, 0, -1]
+                novel_ap = precisions[:, :, novel_inds, 0, -1]
+                base_ap50 = precisions[0, :, base_inds, 0, -1]
+                novel_ap50 = precisions[0, :, novel_inds, 0, -1]
+
+                eval_results['base_ap'] = np.mean(
+                    base_ap[base_ap > -1]) if len(
+                        base_ap[base_ap > -1]) else -1
+                eval_results['novel_ap'] = np.mean(
+                    novel_ap[novel_ap > -1]) if len(
+                        novel_ap[novel_ap > -1]) else -1
+                eval_results['base_ap50'] = np.mean(
+                    base_ap50[base_ap50 > -1]) if len(
+                        base_ap50[base_ap50 > -1]) else -1
+                eval_results['novel_ap50'] = np.mean(
+                    novel_ap50[novel_ap50 > -1]) if len(
+                        novel_ap50[novel_ap50 > -1]) else -1
+                # --------end of novel_ap50 and base_ap50----------
+                if metric_items is None:
+                    metric_items = [
+                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+                    ]
+
+                for metric_item in metric_items:
+                    key = f'{metric}_{metric_item}'
+                    val = coco_eval.stats[coco_metric_names[metric_item]]
+                    eval_results[key] = float(f'{round(val, 3)}')
+
+                ap = coco_eval.stats[:6]
+                logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
+                            f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+                            f'{ap[4]:.3f} {ap[5]:.3f}')
+
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+        return eval_results
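
For reference, a minimal standalone sketch of the base/novel AP split performed by OVCocoMetric above. The toy precisions array mimics COCOeval's eval['precision'] of shape (iou_thrs, recall_thrs, classes, area_ranges, max_dets); the array sizes and category ids here are illustrative assumptions, not values taken from the patch.

import numpy as np

# Toy stand-in for coco_eval.eval['precision']; cocoeval marks
# undefined precision entries with -1.
precisions = np.random.rand(10, 101, 65, 4, 3)
precisions[precisions < 0.05] = -1

cat_ids = list(range(1, 66))       # hypothetical contiguous category ids
base_cat_ids = set(cat_ids[:48])   # 48 base classes, as in the 48/17 split
novel_cat_ids = set(cat_ids[48:])  # 17 novel classes

novel_inds = [i for i, c in enumerate(cat_ids) if c in novel_cat_ids]

# iou index 0 is IoU=0.50, area index 0 is 'all', max-dets index -1 is
# the largest maxDets setting, matching the slicing in the metric above.
novel_ap50 = precisions[0, :, novel_inds, 0, -1]
valid = novel_ap50[novel_ap50 > -1]
print('novel_ap50:', valid.mean() if valid.size else -1)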