From f14353df9520bd0b44af60e96d4c91587fdfe07d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
<1286304229@qq.com>
Date: Mon, 9 Oct 2023 13:37:04 +0800
Subject: [PATCH] Add custom dataset of grounding dino (#11012)
---
configs/grounding_dino/README.md | 96 +++++++++++++++++++
...nding_dino_swin-t_finetune_8xb2_20e_cat.py | 56 +++++++++++
2 files changed, 152 insertions(+)
create mode 100644 configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py
diff --git a/configs/grounding_dino/README.md b/configs/grounding_dino/README.md
index 2c869adffc9..715b630cc79 100644
--- a/configs/grounding_dino/README.md
+++ b/configs/grounding_dino/README.md
@@ -74,3 +74,99 @@ Note:
1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being.
2. Finetune refers to fine-tuning on the COCO 2017 dataset. The R50 model is trained using 8 NVIDIA GeForce 3090 GPUs, while the remaining models are trained using 16 NVIDIA GeForce 3090 GPUs. The GPU memory usage is approximately 8.5GB.
3. Our performance is higher than the official model due to two reasons: we modified the initialization strategy and introduced a log scaler.
+
+## Custom Dataset
+
+To facilitate fine-tuning on custom datasets, we use a simple cat dataset as an example, as shown in the following steps.
+
+### 1. Dataset Preparation
+
+```shell
+cd mmdetection
+wget https://download.openmmlab.com/mmyolo/data/cat_dataset.zip
+unzip cat_dataset.zip -d data/cat/
+```
+
+cat dataset is a single-category dataset with 144 images, which has been converted to coco format.
+
+
+
+
+
+### 2. Config Preparation
+
+Due to the simplicity and small number of cat datasets, we use 8 cards to train 20 epochs, scale the learning rate accordingly, and do not train the language model, only the visual model.
+
+The Details of the configuration can be found in [grounding_dino_swin-t_finetune_8xb2_20e_cat](grounding_dino_swin-t_finetune_8xb2_20e_cat.py)
+
+### 3. Visualization and Evaluation
+
+Due to the Grounding DINO is an open detection model, so it can be detected and evaluated even if it is not trained on the cat dataset.
+
+The single image visualization is as follows:
+
+```shell
+cd mmdetection
+python demo/image_demo.py data/cat/images/IMG_20211205_120756.jpg configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py --weights https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth --texts cat.
+```
+
+
+
+
+
+The test dataset evaluation on single card is as follows:
+
+```shell
+python tools/test.py configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth
+```
+
+```text
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.867
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 1.000
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.931
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.867
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.903
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.907
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.907
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.907
+```
+
+### 4. Model Training and Visualization
+
+```shell
+./tools/dist_train.sh configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py 8 --work-dir cat_work_dir
+```
+
+The model will be saved based on the best performance on the test set. The performance of the best model (at epoch 16) is as follows:
+
+```text
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.905
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 1.000
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.923
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.905
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.927
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.937
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.937
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.937
+```
+
+We can find that after fine-tuning training, the training of the cat dataset is increased from 86.7 to 90.5.
+
+If we do single image inference visualization again, the result is as follows:
+
+```shell
+cd mmdetection
+python demo/image_demo.py data/cat/images/IMG_20211205_120756.jpg configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py --weights cat_work_dir/best_coco_bbox_mAP_epoch_16.pth --texts cat.
+```
+
+
+
+
diff --git a/configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py b/configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py
new file mode 100644
index 00000000000..c2265e86730
--- /dev/null
+++ b/configs/grounding_dino/grounding_dino_swin-t_finetune_8xb2_20e_cat.py
@@ -0,0 +1,56 @@
+_base_ = 'grounding_dino_swin-t_finetune_16xb2_1x_coco.py'
+
+data_root = 'data/cat/'
+class_name = ('cat', )
+num_classes = len(class_name)
+metainfo = dict(classes=class_name, palette=[(220, 20, 60)])
+
+model = dict(bbox_head=dict(num_classes=num_classes))
+
+train_dataloader = dict(
+ dataset=dict(
+ data_root=data_root,
+ metainfo=metainfo,
+ ann_file='annotations/trainval.json',
+ data_prefix=dict(img='images/')))
+
+val_dataloader = dict(
+ dataset=dict(
+ metainfo=metainfo,
+ data_root=data_root,
+ ann_file='annotations/test.json',
+ data_prefix=dict(img='images/')))
+
+test_dataloader = val_dataloader
+
+val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
+test_evaluator = val_evaluator
+
+max_epoch = 20
+
+default_hooks = dict(
+ checkpoint=dict(interval=1, max_keep_ckpts=1, save_best='auto'),
+ logger=dict(type='LoggerHook', interval=5))
+train_cfg = dict(max_epochs=max_epoch, val_interval=1)
+
+param_scheduler = [
+ dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=30),
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epoch,
+ by_epoch=True,
+ milestones=[15],
+ gamma=0.1)
+]
+
+optim_wrapper = dict(
+ optimizer=dict(lr=0.00005),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'absolute_pos_embed': dict(decay_mult=0.),
+ 'backbone': dict(lr_mult=0.1),
+ 'language_model': dict(lr_mult=0),
+ }))
+
+auto_scale_lr = dict(base_batch_size=16)