forked from wangxu-scu/DRSL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
55 lines (41 loc) · 1.36 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
# from model import TextDNN, ImageDNN, DNN, CNN
from model import Model
import numpy as np
# torch.manual_seed(torch_seed)
# np.random.seed(np_seed)
from custom_dataset import MyCustomDataset
from torch.utils.data import DataLoader
def select_device(preferred_index=1):
    """Return the torch device to run on.

    Uses CUDA device ``preferred_index`` when the machine has that many
    GPUs. Falls back to ``cuda:0`` when it has fewer: a hard-coded
    ``cuda:1`` on a single-GPU machine fails with an invalid device
    ordinal as soon as a module is moved onto it. Uses the CPU when CUDA
    is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    index = preferred_index if preferred_index < torch.cuda.device_count() else 0
    return torch.device(f"cuda:{index}")


# The original script targeted the second GPU ("cuda:1"); keep that as the default.
device = select_device()
# Dataset selection -- exactly one configuration should be active.
#
# Alternative: XMedia, where each sample of every modality is classified
# into one of 200 independent classes:
# dataset_config = dict(dataset_name='xmedianet_deep', class_number=200)
#
# Active: Pascal.
dataset_config = dict(
    dataset_name='pascal_deep',
    class_number=5,
)

# Mini-batch size used for each data split.
batch_size = dict(train=200, test=200)
def build_dataloaders(dataset_name, batch_sizes, num_workers=1):
    """Create the 'train' and 'test' datasets and wrap each in a DataLoader.

    Args:
        dataset_name: name passed to ``MyCustomDataset(dataset=...)``.
        batch_sizes: dict mapping 'train' / 'test' to a batch size.
        num_workers: DataLoader worker processes per split.

    Returns:
        ``(dataloaders, dataset_sizes)``, both dicts keyed by 'train' and
        'test'; ``dataset_sizes`` holds ``len()`` of each dataset.
    """
    splits = ['train', 'test']
    datasets = {
        split: MyCustomDataset(dataset=dataset_name, state=split)
        for split in splits
    }
    # Only the training split is shuffled; the test split keeps a fixed
    # order (presumably so evaluation is deterministic -- confirm in train.py).
    dataloaders = {
        split: DataLoader(datasets[split], batch_size=batch_sizes[split],
                          shuffle=(split == 'train'), num_workers=num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(datasets[split]) for split in splits}
    return dataloaders, dataset_sizes


def build_model():
    """Construct the Model with the script's fixed layer dimensions.

    NOTE(review): the _I / _T / _R suffixes presumably denote the image
    branch, the text branch and a third head (R); 4096 / 300 are the
    expected input feature sizes of the two modalities -- confirm
    against model.py.
    """
    return Model(
        input_dim_I=4096,
        input_dim_T=300,
        hidden_dim_I=1024,
        hidden_dim_T=1024,
        hidden_dim_R=1024,
        output_dim_I=300,
        output_dim_T=300,
        output_dim_R=1
    )


def main():
    """Build the data pipeline and model, train for 20 epochs, return the model."""
    dataloaders, dataset_sizes = build_dataloaders(
        dataset_config['dataset_name'], batch_size)
    model = build_model()
    model.to(device)
    # Imported lazily, after data/model setup, mirroring the original script.
    import train
    # ``retreival`` (sic) is the keyword name train.train2 expects; do not
    # "fix" its spelling here without changing train.py as well.
    return train.train2(model, dataloaders, device, dataset_sizes,
                        num_epochs=20, retreival=True)


if __name__ == "__main__":
    # Required because the DataLoaders use worker processes: under the
    # "spawn" start method (Windows, macOS) each worker re-imports this
    # module, and unguarded top-level code would re-run the whole script.
    model = main()