forked from zouxiaochuan/code_ogblsc2022
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_edgenorm_cls.py
104 lines (86 loc) · 2.98 KB
/
predict_edgenorm_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import datasets as datasets
import torch.utils.data
from config import config
import models as models
import torch_utils
import torch.optim
import timm.scheduler
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import torch.distributed as dist
import pickle
import common_utils
# Name of the training run; predict() uses it as the checkpoint sub-directory
# under 'models_valid'.
run_id = 'dropout0.1_decay1_0.97_h32s32_hidden256_classnorm_lesse_train_2'
# Epoch of the checkpoint to load; predict() formats it zero-padded to three
# digits in the file name 'epoch_XXX.pt'.
iepoch = 50
# CUDA device that predict() selects and moves the model and batches to.
device = 'cuda:2'
def predict():
    """Score every sample of the 'all' split with a saved classifier checkpoint.

    Builds the dataset and a non-shuffling loader, restores the checkpoint
    ``models_valid/<run_id>/epoch_<iepoch>.pt`` into a
    ``MoleculePairDistClassifier``, and runs it in eval mode on ``device``.

    Returns:
        list: one NumPy array per sample, in loader order. Each array is the
        sigmoid of the model output for that sample, cropped on its first two
        axes to ``n`` entries, where ``n`` is the sum of the sample's
        ``atom_mask`` row.
    """
    dataset = datasets.SimplePCQM4MDataset(
        path=config['middle_data_path'],
        split_name='all',
        rotate=False,
        data_path_name='data',
        load_dist=True,
    )
    # shuffle=False keeps the output order aligned with the dataset order, so
    # the position of a result in the returned list can be used as its index.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        num_workers=8,
        collate_fn=datasets.collate_fn,
        shuffle=False,
    )

    torch.cuda.set_device(device)
    torch.cuda.empty_cache()

    model = models.MoleculePairDistClassifier(config)
    print('num of parameters: {0}'.format(np.sum([p.numel() for p in model.parameters()])))

    model_save_path = os.path.join('models_valid', run_id)
    sd = torch.load(os.path.join(model_save_path, f'epoch_{iepoch:03d}.pt'), map_location='cpu')
    # NOTE(review): the checkpoint keys presumably carry the 'module.' prefix
    # that DistributedDataParallel adds (7 characters, matching the original
    # slice) -- confirm against the training script. Strip it only when it is
    # actually present so a checkpoint saved without the wrapper still loads
    # instead of having the first 7 characters of every key cut off.
    ddp_prefix = 'module.'
    sd = {
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
        for k, v in sd.items()
    }
    model.load_state_dict(sd)
    model.to(device)
    model.eval()

    scores_list = []
    for batch in tqdm(loader):
        # The second element of the batch is not needed for prediction.
        graph, _ = batch
        graph = torch_utils.batch_to_device(graph, device)
        with torch.no_grad():
            scores = model(
                graph['atom_feat_cate'],
                graph['atom_feat_float'],
                graph['atom_mask'],
                graph['bond_index'],
                graph['bond_feat_cate'],
                graph['bond_feat_float'],
                graph['bond_mask'],
                graph['structure_feat_cate'],
                graph['structure_feat_float'],
                graph['triplet_feat_cate'])
            scores = torch.sigmoid(scores)
        scores = scores.detach().cpu().numpy()
        num_atom = graph['atom_mask'].sum(dim=1).detach().cpu().numpy().astype('int64')
        # Drop the padded rows/columns: keep only the leading n x n entries,
        # with n taken from the sample's atom mask.
        for n, sample_scores in zip(num_atom, scores):
            scores_list.append(sample_scores[:n, :n, :])
    return scores_list
def process_fn(param):
    """Attach one prediction to its sample file on disk.

    Args:
        param: a ``(prediction, index)`` pair. ``index`` selects the file
            ``<middle_data_path>/data/<index // 1000, 4 digits>/<index, 7 digits>.pkl``.

    The file is loaded, the prediction is stored under the key
    ``'predict_pair_dist_cls'`` of its first element, and the pair is saved
    back to the same path.
    """
    prediction, index = param
    shard_dir = format(index // 1000, '04d')
    file_name = format(index, '07d') + '.pkl'
    sample_path = os.path.join(config['middle_data_path'], 'data', shard_dir, file_name)

    # presumably (graph dict, target) as produced by the preprocessing step --
    # only the first element is modified here.
    graph, target = common_utils.load_obj(sample_path)
    graph['predict_pair_dist_cls'] = prediction
    common_utils.save_obj((graph, target), sample_path)
def main():
    """Predict for every sample and write each prediction into its data file.

    Runs ``predict()``, then hands ``(prediction, index)`` pairs to a process
    pool whose workers call ``process_fn``. The index of a pair is the
    position of the prediction in the list returned by ``predict()``.
    """
    preds = predict()
    params = [(pred, i) for i, pred in enumerate(preds)]

    # The context manager guarantees the worker processes are cleaned up even
    # if one of the tasks raises while the results are being consumed.
    with mp.Pool() as pool:
        results = pool.imap_unordered(process_fn, params, chunksize=1024)
        # Wrap the results (not the inputs) so the bar advances when a file
        # has actually been written rather than when a task is dispatched.
        for _ in tqdm(results, total=len(params)):
            pass
        # All tasks are done at this point; shut the workers down gracefully.
        pool.close()
        pool.join()
# Script entry point: run the full predict-and-write-back pipeline only when
# this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()