-
Notifications
You must be signed in to change notification settings - Fork 78
/
ban_vit-b16-in21k_mit-b0_512x512_40k_levircd.py
88 lines (83 loc) · 2.86 KB
/
ban_vit-b16-in21k_mit-b0_512x512_40k_levircd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Start from the BAN ViT-B/16 model skeleton and the shared LEVIR-CD
# 512x512 / 40k-iteration setup; everything below overrides those bases.
_base_ = [
    '../_base_/models/ban_vit-b16.py',
    '../common/standard_512x512_40k_levircd.py',
]

# Window size for sliding-window inference (used by ``model.test_cfg``).
crop_size = (512, 512)

# Local (relative-path) weights for the ViT image encoder; by its file name an
# AugReg ViT-B/16 ImageNet-21k checkpoint — it must exist under ./pretrain/.
vit_checkpoint_file = 'pretrain/augreg_B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.pth'  # noqa
# OpenMMLab-hosted SegFormer MiT-B0 weights for the side encoder.
checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth'  # noqa

model = dict(
    # Weights come from the per-encoder ``init_cfg`` entries instead.
    pretrained=None,
    # Keep this spelling: it is presumably the exact keyword the BAN model
    # class accepts (the class itself is not visible here) — confirm.
    asymetric_input=True,
    # Presumably the size the image-encoder input is resized to; it matches
    # ``image_encoder.img_size`` below — confirm against the BAN model.
    encoder_resolution=dict(size=(224, 224), mode='bilinear'),
    image_encoder=dict(
        type='mmseg.VisionTransformer',
        init_cfg=dict(type='Pretrained', checkpoint=vit_checkpoint_file),
        img_size=(224, 224),
        patch_size=16,
        in_channels=3,
        # ViT-B geometry: width 768, 12 blocks, 12 heads.
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        # 0-based block indices (11 is the last of 12) whose outputs are used.
        out_indices=(5, 8, 11),
        qkv_bias=True,
        # Dropout, attention dropout and drop-path are all disabled.
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        patch_bias=True,
        interpolate_mode='bicubic',
        # NOTE(review): in mmseg's VisionTransformer an empty exclude list
        # appears to freeze every encoder parameter — verify for your mmseg.
        frozen_exclude=[],
    ),
    decode_head=dict(
        type='BitemporalAdapterHead',
        ban_cfg=dict(
            # Equals the ViT width (``image_encoder.embed_dims``).
            clip_channels=768,
            # Presumably the side-encoder stages that receive ViT features.
            fusion_index=[1, 2, 3],
            # Side encoder: SegFormer MiT-B0 (per its checkpoint URL).
            side_enc_cfg=dict(
                type='mmseg.MixVisionTransformer',
                init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
                in_channels=3,
                embed_dims=32,
                num_stages=4,
                num_layers=[2, 2, 2, 2],
                num_heads=[1, 2, 5, 8],
                patch_sizes=[7, 3, 3, 3],
                sr_ratios=[8, 4, 2, 1],
                out_indices=(0, 1, 2, 3),
                mlp_ratio=4,
                qkv_bias=True,
                drop_rate=0.0,
                attn_drop_rate=0.0,
                drop_path_rate=0.1,
            ),
        ),
        ban_dec_cfg=dict(
            type='BAN_MLPDecoder',
            # MixVisionTransformer stage widths are embed_dims * num_heads[i],
            # i.e. 32 * [1, 2, 5, 8]; keep in sync with side_enc_cfg.
            in_channels=[32, 64, 160, 256],
            channels=128,
            dropout_ratio=0.1,
            num_classes=2,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            align_corners=False,
        ),
    ),
    # Sliding-window inference with 50% overlap: stride is half the crop.
    test_cfg=dict(
        mode='slide',
        crop_size=crop_size,
        stride=tuple(side // 2 for side in crop_size),
    ),
)
# Replace (``_delete_=True``) the inherited optimizer wrapper entirely rather
# than merging into it; AmpOptimWrapper trains with automatic mixed precision.
optim_wrapper = dict(
    _delete_=True,
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=1e-4,
    ),
    # mmengine applies a ``custom_keys`` entry to every parameter whose name
    # contains that key as a substring.
    # NOTE(review): ``model`` in this file names its components
    # ``image_encoder`` and ``decode_head``; if parameter names follow those
    # keys, 'img_encoder' and 'mask_decoder' match nothing and only 'norm'
    # takes effect. Kept unchanged so published results stay reproducible —
    # confirm before renaming.
    paramwise_cfg=dict(
        custom_keys=dict(
            img_encoder=dict(lr_mult=0.1, decay_mult=1.0),
            norm=dict(decay_mult=0.0),  # no weight decay on norm parameters
            mask_decoder=dict(lr_mult=10.0),
        ),
    ),
    loss_scale='dynamic',
    clip_grad=dict(max_norm=0.01, norm_type=2),
)

# Per-GPU batch size / worker overrides, merged into the inherited
# dataloader configs (no ``_delete_``, so other fields are kept).
train_dataloader = dict(batch_size=8, num_workers=8)
val_dataloader = dict(batch_size=1, num_workers=1)

# Uncomment if distributed training reports parameters that received no
# gradient; the mmengine Runner forwards this flag to its DDP wrapper.
# find_unused_parameters=True