-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling_finetune.py
606 lines (518 loc) · 26 KB
/
modeling_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from einops import rearrange, repeat
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, mask=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# me: support window mask
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = attn.softmax(dim=-1)
else:
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, mask=None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), mask=mask))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
"""
adapted from https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_pytorch.py
"""
# support cross attention
class GeneralAttention(nn.Module):
def __init__(
self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.kv = nn.Linear(dim if context_dim is None else context_dim, all_head_dim * 2, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, context=None):
B, T1, C = x.shape
q_bias, kv_bias = self.q_bias, None
if self.q_bias is not None:
kv_bias = torch.cat((torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, T1, self.num_heads, -1).transpose(1,2) # me: (B, H, T1, C//H)
kv = F.linear(input=x if context is None else context, weight=self.kv.weight, bias=kv_bias)
_, T2, _ = kv.shape
kv = kv.reshape(B, T2, 2, self.num_heads, -1).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple), me: (B, H, T2, C//H)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # me: (B, H, T1, T2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, T1, -1) # (B, H, T1, C//H) -> (B, T1, H, C//H) -> (B, T1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# local + global
class LGBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
attn_head_dim=None,
# new added
first_attn_type='self', third_attn_type='cross',
attn_param_sharing_first_third=False, attn_param_sharing_all=False,
no_second=False, no_third=False,
):
super().__init__()
assert first_attn_type in ['self', 'cross'], f"Error: invalid attention type '{first_attn_type}', expected 'self' or 'cross'!"
assert third_attn_type in ['self', 'cross'], f"Error: invalid attention type '{third_attn_type}', expected 'self' or 'cross'!"
self.first_attn_type = first_attn_type
self.third_attn_type = third_attn_type
self.attn_param_sharing_first_third = attn_param_sharing_first_third
self.attn_param_sharing_all = attn_param_sharing_all
# Attention layer
## perform local (intra-region) attention, update messenger tokens
## (local->messenger) or (local<->local, local<->messenger)
self.first_attn_norm0 = norm_layer(dim)
if self.first_attn_type == 'cross':
self.first_attn_norm1 = norm_layer(dim)
self.first_attn = GeneralAttention(
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
## perform global (inter-region) attention on messenger tokens
## (messenger<->messenger)
self.no_second = no_second
if not no_second:
self.second_attn_norm0 = norm_layer(dim)
if attn_param_sharing_all:
self.second_attn = self.first_attn
else:
self.second_attn = GeneralAttention(
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
## perform local (intra-region) attention to inject global information into local tokens
## (messenger->local) or (local<->local, local<->messenger)
self.no_third = no_third
if not no_third:
self.third_attn_norm0 = norm_layer(dim)
if self.third_attn_type == 'cross':
self.third_attn_norm1 = norm_layer(dim)
if attn_param_sharing_first_third or attn_param_sharing_all:
self.third_attn = self.first_attn
else:
self.third_attn = GeneralAttention(
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
# FFN layer
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values > 0:
self.gamma_0 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_0, self.gamma_1, self.gamma_2 = None, None, None
def forward(self, x, b):
"""
:param x: (B*N, S, C),
B: batch size
N: number of local regions
S: 1 + region size, 1: attached messenger token for each local region
C: feature dim
param b: batch size
:return: (B*N, S, C),
"""
bn = x.shape[0]
n = bn // b # number of local regions
if self.gamma_1 is None:
# Attention layer
## perform local (intra-region) self-attention
if self.first_attn_type == 'self':
x = x + self.drop_path(self.first_attn(self.first_attn_norm0(x)))
else: # 'cross'
x[:,:1] = x[:,:1] + self.drop_path(
self.first_attn(
self.first_attn_norm0(x[:,:1]), # (b*n, 1, c)
context=self.first_attn_norm1(x[:,1:]) # (b*n, s-1, c)
)
)
## perform global (inter-region) self-attention
if not self.no_second:
# messenger_tokens: representative tokens
messenger_tokens = rearrange(x[:,0], '(b n) c -> b n c', b=b) # attn on 'n' dim
messenger_tokens = messenger_tokens + self.drop_path(
self.second_attn(self.second_attn_norm0(messenger_tokens))
)
x[:,0] = rearrange(messenger_tokens, 'b n c -> (b n) c')
else: # for usage in the third attn
messenger_tokens = rearrange(x[:,0], '(b n) c -> b n c', b=b) # attn on 'n' dim
## perform local-global interaction
if not self.no_third:
if self.third_attn_type == 'self':
x = x + self.drop_path(self.third_attn(self.third_attn_norm0(x)))
else:
local_tokens = rearrange(x[:,1:], '(b n) s c -> b (n s) c', b=b)# NOTE: n merges into s (not b), (B, N*(S-1), D)
local_tokens = local_tokens + self.drop_path(
self.third_attn(
self.third_attn_norm0(local_tokens), # (b, n*(s-1), c)
context=self.third_attn_norm1(messenger_tokens) # (b, n*1, c)
)
)
x[:,1:] = rearrange(local_tokens, 'b (n s) c -> (b n) s c', n=n)
# FFN layer
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
raise NotImplementedError
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.tubelet_size = int(tubelet_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
# me: for more attention types
self.temporal_seq_len = num_frames // self.tubelet_size
self.spatial_num_patches = num_patches // self.temporal_seq_len
self.input_token_size = (num_frames // self.tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
stride=(self.tubelet_size, patch_size[0], patch_size[1]))
def forward(self, x, **kwargs):
B, C, T, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=0.,
use_learnable_pos_emb=False,
init_scale=0.,
all_frames=16,
tubelet_size=2,
use_mean_pooling=True,
keep_temporal_dim=False, # do not perform temporal pooling, has higher priority than 'use_mean_pooling'
head_activation_func=None, # activation function after head fc, mainly for the regression task
attn_type='joint',
lg_region_size=(2, 2, 10), lg_first_attn_type='self', lg_third_attn_type='cross', # for local_global
lg_attn_param_sharing_first_third=False, lg_attn_param_sharing_all=False,
lg_classify_token_type='org', lg_no_second=False, lg_no_third=False,
):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.tubelet_size = tubelet_size
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
num_patches = self.patch_embed.num_patches
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
else:
# sine-cosine positional embeddings is on the way
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# me: support more attention types
self.attn_type = attn_type
if attn_type == 'joint':
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values)
for i in range(depth)])
elif attn_type == 'local_global':
print(f"==> Note: Use 'local_global' for compute reduction (lg_region_size={lg_region_size},"
f"lg_first_attn_type={lg_first_attn_type}, lg_third_attn_type={lg_third_attn_type},"
f"lg_attn_param_sharing_first_third={lg_attn_param_sharing_first_third},"
f"lg_attn_param_sharing_all={lg_attn_param_sharing_all},"
f"lg_classify_token_type={lg_classify_token_type},"
f"lg_no_second={lg_no_second}, lg_no_third={lg_no_third})")
self.blocks = nn.ModuleList([
LGBlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values,
first_attn_type=lg_first_attn_type, third_attn_type=lg_third_attn_type,
attn_param_sharing_first_third=lg_attn_param_sharing_first_third,
attn_param_sharing_all=lg_attn_param_sharing_all,
no_second=lg_no_second, no_third=lg_no_third,
)
for i in range(depth)])
# region tokens
self.lg_region_size = lg_region_size # (t, h, w)
self.lg_num_region_size = list(i//j for i,j in zip(self.patch_embed.input_token_size, lg_region_size)) # (nt, nh, nw)
num_regions = self.lg_num_region_size[0] * self.lg_num_region_size[1] * self.lg_num_region_size[2] # nt * nh * nw
print(f"==> Number of local regions: {num_regions} (size={self.lg_num_region_size})")
self.lg_region_tokens = nn.Parameter(torch.zeros(num_regions, embed_dim))
trunc_normal_(self.lg_region_tokens, std=.02)
# The token type used for final classification
self.lg_classify_token_type = lg_classify_token_type
assert lg_classify_token_type in ['org', 'region', 'all'], \
f"Error: wrong 'lg_classify_token_type' in local_global attention ('{lg_classify_token_type}'), expected 'org'/'region'/'all'!"
else:
raise NotImplementedError
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# me: add frame-level prediction support
self.keep_temporal_dim = keep_temporal_dim
# me: add head activation function support for regression task
if head_activation_func is not None:
if head_activation_func == 'sigmoid':
self.head_activation_func = nn.Sigmoid()
elif head_activation_func == 'relu':
self.head_activation_func = nn.ReLU()
elif head_activation_func == 'tanh':
self.head_activation_func = nn.Tanh()
else:
raise NotImplementedError
else: # default
self.head_activation_func = nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'part_tokens', 'lg_region_tokens'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
B, _, _ = x.size()
if self.pos_embed is not None:
x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
x = self.pos_drop(x)
if self.attn_type == 'local_global':
# input: region partition
nt, t = self.lg_num_region_size[0], self.lg_region_size[0]
nh, h = self.lg_num_region_size[1], self.lg_region_size[1]
nw, w = self.lg_num_region_size[2], self.lg_region_size[2]
b = x.size(0)
x = rearrange(x, 'b (nt t nh h nw w) c -> b (nt nh nw) (t h w) c', nt=nt,nh=nh,nw=nw,t=t,h=h,w=w)
# add region (i.e., representative) tokens
region_tokens = repeat(self.lg_region_tokens, 'n c -> b n 1 c', b=b)
x = torch.cat([region_tokens, x], dim=2) # (b, nt*nh*nw, 1+thw, c)
x = rearrange(x, 'b n s c -> (b n) s c') # s = 1 + thw
# run through each block
for blk in self.blocks:
x = blk(x, b) # (b*n, s, c)
x = rearrange(x, '(b n) s c -> b n s c', b=b) # s = 1 + thw
# token for final classification
if self.lg_classify_token_type == 'region': # only use region tokens for classification
x = x[:,:,0] # (b, n, c)
elif self.lg_classify_token_type == 'org': # only use original tokens for classification
x = rearrange(x[:,:,1:], 'b n s c -> b (n s) c') # s = thw
else: # use all tokens for classification
x = rearrange(x, 'b n s c -> b (n s) c') # s = 1 + thw
else:
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.fc_norm is not None:
# me: add frame-level prediction support
if self.keep_temporal_dim:
x = rearrange(x, 'b (t hw) c -> b c t hw',
t=self.patch_embed.temporal_seq_len,
hw=self.patch_embed.spatial_num_patches)
# spatial mean pooling
x = x.mean(-1) # (B, C, T)
# temporal upsample: 8 -> 16, for patch embedding reduction
x = torch.nn.functional.interpolate(
x, scale_factor=self.patch_embed.tubelet_size,
mode='linear'
)
x = rearrange(x, 'b c t -> b t c')
return self.fc_norm(x)
else:
return self.fc_norm(x.mean(1))
else:
return x[:, 0]
def forward(self, x, save_feature=False):
x = self.forward_features(x)
if save_feature:
feature = x
x = self.head(x)
# me: add head activation function support
x = self.head_activation_func(x)
# me: add frame-level prediction support
if self.keep_temporal_dim:
x = x.view(x.size(0), -1) # (B,T,C) -> (B,T*C)
if save_feature:
return x, feature
else:
return x
@register_model
def vit_base_patch16_160(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=160,
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def vit_base_dim512_no_depth_patch16_160(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=160,
patch_size=16, embed_dim=512, num_heads=8, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
return model