-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrepresentation.py
82 lines (73 loc) · 3.83 KB
/
representation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
import os
import sys
import re
import argparse
from types import SimpleNamespace
from datetime import datetime
from getpass import getuser
import numpy as np
import train
# Command-line interface of this script. Only --checkpoint is mandatory;
# every other option has a default.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--checkpoint',
    type=str,
    required=True,
    help='path to the checkpoint log file',
)
parser.add_argument(
    '--logdir',
    type=str,
    default='logs/evaluation',
    help='dir for the new log file',
)
parser.add_argument(
    '--dataset',
    type=str,
    default=None,
    help='dataset (overwrites the one from the log)',
)
parser.add_argument(
    '--seed',
    type=int,
    default=None,
    help='seed (overwrites the one from the log)',
)
parser.add_argument(
    '--xtb',
    type=int,
    default=None,
    help='1 for xtb geometries 0 for dft (overwrites the one from the log)',
)
parser.add_argument(
    '--xtb_subset',
    action='store_true',
    help='use the xtb subset (for checkpoints older than df87099b0)',
)

# Parsed options of *this* script (kept separate from the training `args`
# that are later recovered from the checkpoint log).
script_args = parser.parse_args()
# Directory that receives the log file of this run.
run_dir = script_args.logdir
# exist_ok=True creates the directory only when needed and avoids the race
# between a separate existence check and the creation.
os.makedirs(run_dir, exist_ok=True)

# Log name: timestamp (with microseconds) followed by the current user name.
logname = f'{datetime.now().strftime("%y%m%d-%H%M%S.%f")}-{getuser()}'
logpath = os.path.join(run_dir, f'{logname}.log')
print(f"stdout/stderr to {logpath}")

# From here on stdout and stderr are wrapped by train.Logger, which is given
# both the log path and the original stream.
# NOTE(review): presumably Logger tees output to the file and the console;
# confirm against train.Logger.
sys.stdout = train.Logger(logpath=logpath, syspart=sys.stdout)
sys.stderr = train.Logger(logpath=logpath, syspart=sys.stderr)
# Recover the training configuration and the logged results from the
# checkpoint log file given on the command line.
with open(script_args.checkpoint, 'r') as f:
    lines = f.readlines()

args = None         # hyperparameters, from the first 'input args Namespace' line
best_epoch = None   # last token of the last 'best mae was in' line, minus one
mae_logged = None   # third-from-last token of the last 'Mean MAE across splits' line

for line in lines:
    if args is None and 'input args Namespace' in line:
        # The logged Namespace(...) repr is turned into a SimpleNamespace(...)
        # expression and evaluated.
        # SECURITY NOTE: eval() executes whatever the log line contains;
        # only point --checkpoint at log files you trust.
        args = eval(line.strip().replace('input args Namespace', 'SimpleNamespace'))
    if 'and the best mae was in' in line:
        best_epoch = int(line.split()[-1]) - 1
    if 'Mean MAE across splits' in line:
        mae_logged = float(line.split()[-3])

# Fail early with a clear message instead of a NameError further down.
if args is None:
    raise ValueError(
        f"no 'input args Namespace' line found in {script_args.checkpoint}"
    )
if best_epoch is None:
    raise ValueError(
        f"no 'and the best mae was in' line found in {script_args.checkpoint}"
    )
if mae_logged is None:
    # The logged MAE is only used for the final comparison, so a missing line
    # should not prevent the run; the comparison will then report NaN.
    print(f"warning: no 'Mean MAE across splits' line found in {script_args.checkpoint}")
    mae_logged = float('nan')
# Show the hyperparameters as recovered from the checkpoint log.
print(args)
print()

# Settings specific to this run.
args.logdir = script_args.logdir
args.experiment_name = None
args.wandb_name = None
args.num_epochs = best_epoch
# The model checkpoint path is derived from the log path.
args.checkpoint = script_args.checkpoint.replace('.log', '.best_checkpoint.pt')
args.eval_on_test_split = False
# NOTE(review): return_repr is not read anywhere in this script; the train
# call passes print_repr=True directly. Kept to avoid changing module globals.
return_repr = True

# Command-line overrides. The options are documented as overwriting the
# values from the log, so apply them whenever they were given.
if script_args.dataset is not None:
    args.dataset = script_args.dataset
if script_args.seed is not None:
    args.seed = script_args.seed
if script_args.xtb is not None:
    # --xtb is parsed as an int (1 for xtb geometries, 0 for dft).
    # NOTE(review): converted to bool on the assumption that the logged
    # value is a boolean flag; confirm against train.train.
    args.xtb = bool(script_args.xtb)

# Fallbacks for options that may be absent from the logged arguments.
if not hasattr(args, 'splitter'):
    args.splitter = 'random'
if not hasattr(args, 'invariant'):
    args.invariant = False
if not hasattr(args, 'train_frac'):
    args.train_frac = 0.9
if not hasattr(args, 'xtb_subset'):
    args.xtb_subset = script_args.xtb_subset

# Show the final settings that will be used.
print(args)
print()
# Hyperparameters forwarded to train.train() unchanged, i.e. as
# `<name>=args.<name>`.
passthrough = (
    'device', 'num_epochs', 'checkpoint',
    'subset', 'dataset', 'process',
    'verbose', 'radius', 'max_neighbors', 'sum_mode',
    'n_s', 'n_v', 'n_conv_layers', 'distance_emb_dim',
    'graph_mode', 'dropout_p', 'random_baseline',
    'combine_mode', 'atom_mapping', 'CV', 'attention',
    'noH', 'two_layers_atom_diff', 'rxnmapper', 'reverse',
    'xtb', 'xtb_subset',
    'splitter',
    'split_complexes', 'lr', 'weight_decay',
    'eval_on_test_split',
    'invariant',
)
train_kwargs = {key: getattr(args, key) for key in passthrough}

# Keyword arguments whose name or value differs from the logged attribute.
train_kwargs.update(
    seed0=args.seed,
    print_repr=True,
    training_fractions=[args.train_frac],
    sweep=True,
)

# Positional arguments are the log directory, the log name, two None
# placeholders and an empty dict.
# NOTE(review): the meaning of the None/None/{} positionals is defined by
# train.train and is not visible here.
maes, _rmses = train.train(run_dir, logname, None, None, {}, **train_kwargs)

# Absolute difference between the MAE read from the log and the mean of the
# MAEs returned by this run.
delta_mae = abs(mae_logged - np.mean(maes))
print(f'delta MAE: {delta_mae:.2e}')