clip_model.py

import torch
from torch import nn

from models import resnet, r2plus1d, avts


def generate_model(opt):
    assert opt.clip_model in [
        'resnet', 'r2plus1d', 'avts', 'i3d'
    ]
    if opt.clip_model == 'resnet':
        assert opt.clip_model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters
        from models.resnet import get_fine_tuning_parameters_layer_lr

        if opt.clip_model_depth == 10:
            model = resnet.resnet10(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.clip_model_depth == 18:
            model = resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.clip_model_depth == 34:
            model = resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.clip_model_depth == 50:
            model = resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                no_last_fc=opt.no_last_fc)
        elif opt.clip_model_depth == 101:
            model = resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.clip_model_depth == 152:
            model = resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.clip_model_depth == 200:
            model = resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.clip_model.lower() in ['r2+1d', 'r2.5d', 'r2plus1d']:
        from models.r2plus1d import get_fine_tuning_parameters, get_fine_tuning_parameters_layer_lr
        print("Making R2+1D model, depth", opt.clip_model_depth)
        model = r2plus1d.r2plus1d_34(num_classes=opt.n_classes)
    elif opt.clip_model.lower() in ['avts']:
        from models.avts import get_fine_tuning_parameters, get_fine_tuning_parameters_layer_lr
        print("Making AVTS model")
        model = avts.mc3_avts()
        # Register a flatten layer and a linear classifier head (256-d input)
        # for the AVTS backbone.
        model.add_module("flatten", torch.nn.Flatten(1))
        model.add_module("fc", nn.Linear(256, opt.n_classes))
    elif opt.clip_model == 'i3d':
        from models import i3res
        if opt.clip_model_depth == 50:
            print("Making I3D model, depth", opt.clip_model_depth)
            model = i3res.i3res_50(sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration,
                                   num_classes=opt.n_classes)
        elif opt.clip_model_depth == 34:
            print("Making I3D model, depth", opt.clip_model_depth)
            model = i3res.i3res_34(sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration,
                                   num_classes=opt.n_classes)

    if opt.distributed:
        # For multiprocessing distributed, the DistributedDataParallel
        # constructor should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if opt.gpu is not None:
            torch.cuda.set_device(opt.gpu)
            model.cuda(opt.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have.
            opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
            opt.n_threads = int(opt.n_threads / opt.ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[opt.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to
            # all available GPUs if device_ids are not set.
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif opt.gpu is not None:
        torch.cuda.set_device(opt.gpu)
        model = model.cuda(opt.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        model = torch.nn.DataParallel(model).cuda()

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path)
        model.load_state_dict(pretrain['state_dict'])

        # Replace the final classification layer so the model can be
        # fine-tuned on opt.n_finetune_classes target classes.
        model.module.fc = nn.Linear(model.module.fc.in_features,
                                    opt.n_finetune_classes)
        model.module.fc = model.module.fc.cuda()

        if opt.layer_lr is not None:
            parameters = get_fine_tuning_parameters_layer_lr(model, opt.ft_begin_index, opt.layer_lr)
        else:
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

        return model, parameters
    # else:
    #     if opt.layer_lr is not None:
    #         parameters = get_fine_tuning_parameters_layer_lr(model, opt.ft_begin_index, opt.layer_lr)
    #     else:
    #         parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    #     return model, parameters

    return model
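

# Usage sketch (illustrative): generate_model expects an argparse-style
# namespace carrying the option names read above. The values below are
# placeholder assumptions rather than the project's defaults, and a
# CUDA-capable device is required because generate_model moves the model
# to GPU.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(
        clip_model='resnet',      # one of 'resnet', 'r2plus1d', 'avts', 'i3d'
        clip_model_depth=50,
        n_classes=400,
        n_finetune_classes=101,
        resnet_shortcut='B',
        sample_size=112,
        sample_duration=16,
        no_last_fc=False,
        distributed=False,
        gpu=None,
        pretrain_path='',         # set to a checkpoint path to fine-tune
        ft_begin_index=0,
        layer_lr=None,
    )

    model = generate_model(opt)
    # With pretrain_path set, generate_model instead returns (model, parameters),
    # and `parameters` can be handed to an optimizer,
    # e.g. torch.optim.SGD(parameters, lr=0.1).
    print(model)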