Skip to content

Commit

Permalink
update moby pretrained model to deit small (alibaba#10)
Browse files Browse the repository at this point in the history
* add auto publish config

* [bugfix]: fix extract.py for benchmarks

* update moby pretrained model to deit small
  • Loading branch information
Cathy0908 authored Apr 18, 2022
1 parent f8aa493 commit b1f67f9
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
_base_ = 'configs/base.py'

# OSS settings: they only take effect when OSS is actually used.
# Glob patterns of extra local files (event files, logs) to sync to OSS
# alongside the local models.
oss_sync_config = {
    'other_file_list': ['**/events.out.tfevents*', '**/*log*'],
}
# Placeholder credentials and endpoints -- fill in real values before use.
oss_io_config = {
    'ak_id': 'your oss ak id',
    'ak_secret': 'your oss ak secret',
    'hosts': 'your oss hosts',
    'buckets': ['your oss buckets'],
}

# model settings
# feature_num is shared by the backbone (feature_num) and the head
# (in_channels), so the two always agree.
#   1920: merge 4 layers of features; open
#         models/backbones/vit_transfomer_dynamic.py:311:
#         self.forward_return_n_last_blocks
#   384:  default
feature_num = 1920
model = {
    'type': 'Classification',
    'pretrained': None,
    'with_sobel': False,
    'backbone': {'type': 'BenchMarkMLP', 'feature_num': feature_num},
    'head': {
        'type': 'ClsHead',
        'with_avg_pool': True,
        'in_channels': feature_num,
        'num_classes': 1000,
    },
}
# dataset settings
# Keys shared by the train and val data sources.
data_source_cfg = {'type': 'SSLSourceImageNetFeature'}

root_path = 'linear_eval/imagenet_features/'
dataset_type = 'ClsDataset'
# Both pipelines are empty: no transforms are configured here.
train_pipeline = []
test_pipeline = []

data = {
    'imgs_per_gpu': 2048,  # total batch 2048*8=16384, 8GPU linear cls
    'workers_per_gpu': 2,
    'train': {
        'type': dataset_type,
        'data_source': {
            'root_path': root_path,
            'training': True,
            **data_source_cfg,
        },
        'pipeline': train_pipeline,
    },
    'val': {
        'type': dataset_type,
        'data_source': {
            'root_path': root_path,
            'training': False,
            **data_source_cfg,
        },
        'pipeline': test_pipeline,
    },
}

# additional hooks

eval_config = {'interval': 5, 'gpu_collect': True}
# Evaluation reuses the very same dict object as data['val'].
eval_pipelines = [
    {
        'mode': 'test',
        'data': data['val'],
        'evaluators': [{'type': 'ClsEvaluator', 'topk': (1, 5)}],
    },
]

# optimizer
optimizer = {'type': 'AdamW', 'lr': 0.001, 'weight_decay': 4e-5}

# learning policy
lr_config = {
    'policy': 'CosineAnnealing',
    'min_lr': 0.0,
    'by_epoch': False,
}

# checkpoint saving interval
checkpoint_config = {'interval': 5}
# runtime settings
total_epochs = 30
109 changes: 109 additions & 0 deletions configs/selfsup/moby/moby_deit_small_p16_4xb128_300e_tfrecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# MoBY self-supervised pretraining config; inherits the shared defaults from
# the base config referenced below.
_base_ = '../../base.py'

# OSS settings are disabled by default. Uncomment the block below (and fill in
# real credentials) when using OSS, to sync local models and logs to OSS.
# oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])
# oss_io_config = dict(
#     ak_id='your oss ak id',
#     ak_secret='your oss ak secret',
#     hosts='your oss hosts',
#     buckets=['your oss buckets'])

# model settings
model = {
    'type': 'MoBY',
    'train_preprocess': ['randomGrayScale', 'gaussianBlur'],
    'queue_len': 4096,
    'momentum': 0.99,
    'pretrained': None,
    'backbone': {
        'type': 'PytorchImageModelWrapper',
        # Alternative backbones, kept for reference together with the
        # original author's remarks (the number is the one neck.in_channels
        # has to match):
        # model_name='pit_xs_distilled_224',
        # model_name='swin_small_patch4_window7_224',  # bad 16G memory will down
        # model_name='swin_tiny_patch4_window7_224',  # good 768
        # model_name='swin_base_patch4_window7_224_in22k',  # bad 16G memory will down
        # model_name='vit_deit_tiny_distilled_patch16_224',  # good 192
        # model_name='vit_deit_small_distilled_patch16_224',  # good 384
        # model_name='xcit_small_12_p16',  # 384
        # model_name='shuffletrans_tiny_p4_w7_224',  # 768
        # model_name='resnet50',  # 2048
        'model_name': 'dynamic_deit_small_p16',  # 384
        'num_classes': 0,
        'pretrained': False,
    },
    'neck': {
        'type': 'MoBYMLP',
        'in_channels': 384,  # the 384 noted for the selected backbone
        'hid_channels': 4096,
        'out_channels': 256,
        'num_layers': 2,
    },
    'head': {
        'type': 'MoBYMLP',
        'in_channels': 256,  # equals neck.out_channels
        'hid_channels': 4096,
        'out_channels': 256,
        'num_layers': 2,
    },
}

dataset_type = 'DaliTFRecordMultiViewDataset'
# Normalisation statistics on a 0-1 scale; `mean` and `std` below are the same
# values multiplied by 255 and are what the pipeline actually receives.
img_norm_cfg = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}
mean = [channel * 255 for channel in img_norm_cfg['mean']]
std = [channel * 255 for channel in img_norm_cfg['std']]
train_pipeline = [
    {'type': 'DaliImageDecoder'},
    {
        'type': 'DaliRandomResizedCrop',
        'size': 224,
        'random_area': (0.2, 1.0),
    },
    {
        'type': 'DaliColorTwist',
        'prob': 0.8,
        'brightness': 0.4,
        'contrast': 0.4,
        'saturation': 0.4,
        'hue': 0.4,
    },
    {
        'type': 'DaliCropMirrorNormalize',
        'crop': [224, 224],
        'mean': mean,
        'std': std,
        'crop_pos_x': [0.0, 1.0],
        'crop_pos_y': [0.0, 1.0],
        'prob': 0.5,
    },
]

data = {
    'imgs_per_gpu': 128,  # total 128*4
    'workers_per_gpu': 4,
    'train': {
        'type': dataset_type,
        'data_source': {
            'type': 'ClsSourceImageNetTFRecord',
            'file_pattern': 'data/imagenet_tfrecord/train-*',
            # pick one of `file_pattern` and `root&list_file`:
            # root='data/imagenet_tfrecord/',
            # list_file='data/imagenet_tfrecord/train_list.txt'
        },
        # Two views, both configured with the same pipeline object.
        'num_views': [1, 1],
        'pipelines': [train_pipeline, train_pipeline],
    },
}

# optimizer
optimizer = {
    'type': 'AdamW',
    'lr': 0.001,  # 0.001 for 512
    'weight_decay': 0.05,
    'trans_weight_decay_set': ['backbone'],
}
# learning policy
lr_config = {
    'policy': 'CosineAnnealing',
    'min_lr': 5e-6,
    # Linear warmup; warmup_by_epoch=True means warmup_iters is in epochs.
    'warmup': 'linear',
    'warmup_iters': 5,
    'warmup_ratio': 0.0001,
    'warmup_by_epoch': True,
}

# checkpoint saving interval
checkpoint_config = {'interval': 10}

# runtime settings
total_epochs = 300

# export config
export = {'export_neck': False}
checkpoint_sync_export = True
4 changes: 2 additions & 2 deletions docs/source/model_zoo_ssl.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Pretrained on **ImageNet** dataset.

| Config | Epochs | Download |
| ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
| [moby_resnet50_4xb128_100e](../../configs/selfsup/moby/moby_rn50_4xb128_100e_tfrecord.py) | 100 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/moby_r50/epoch_100.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/moby_r50/log.txt) |
| [moby_deit_small_p16_4xb128_300e](../../configs/selfsup/moby/moby_deit_small_p16_4xb128_300e_tfrecord.py) | 300 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/moby_deit_small_p16/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/moby_deit_small_p16/log.txt) |

### MoCo V2

Expand Down Expand Up @@ -53,7 +53,7 @@ For detailed usage of benchmark tools, please refer to benchmark [README.md](../
| --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | --------- | ------------------------------------------------------------ |
| SwAV | [swav_resnet50_8xb2048_20e_feature](../../benchmarks/selfsup/classification/imagenet/swav_r50_8xb2048_20e_feature.py) | [swav_resnet50_8xb32_200e](../../configs/selfsup/swav/swav_rn50_8xb32_200e_tfrecord.py) | 73.618 | [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/imagenet_linear_eval/swav_r50_linear_eval/20220216_101719.log.json) |
| DINO | [dino_deit_small_p16_8xb2048_20e_feature](../../benchmarks/selfsup/classification/imagenet/dino_deit_small_p16_8xb2048_20e_feature.py) | [dino_deit_small_p16_8xb32_100e](../../configs/selfsup/dino/dino_deit_small_p16_8xb32_100e_tfrecord.py) | 71.248 | [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/imagenet_linear_eval/dino_deit_small_linear_eval/20220215_141403.log.json) |
| MoBY r50 | [moby_resnet50_8xb2048_20e_feature](../../benchmarks/selfsup/classification/imagenet/moby_r50_8xb2048_20e_feature.py) | [moby_resnet50_4xb128_100e](../../configs/selfsup/moby/moby_rn50_4xb128_100e_tfrecord.py) | 78.392 | [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/imagenet_linear_eval/moby_r50_linear_eval/20220214_135816.log.json) |
| MoBY | [moby_deit_small_p16_8xb2048_30e_feature](../../benchmarks/selfsup/classification/imagenet/moby_deit_small_p16_8xb2048_30e_feature.py) | [moby_deit_small_p16_4xb128_300e](../../configs/selfsup/moby/moby_deit_small_p16_4xb128_300e_tfrecord.py) | 72.214 | [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/imagenet_linear_eval/moby_deit_small_p16_linear_eval/20220414_134929.log.json) |
| MoCo-v2 | [mocov2_resnet50_8xb2048_40e_feature](../../benchmarks/selfsup/classification/imagenet/mocov2_r50_8xb2048_40e_feature.py) | [mocov2_resnet50_8xb32_200e](../../configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_tfrecord.py) | 66.8 | [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/imagenet_linear_eval/mocov2_r50_linear_eval/20220214_143738.log.json) |

### ImageNet Finetuning
Expand Down

0 comments on commit b1f67f9

Please sign in to comment.