-
Notifications
You must be signed in to change notification settings - Fork 7
/
gschnet_qm9_gap_relenergy.yaml
70 lines (65 loc) · 1.96 KB
/
gschnet_qm9_gap_relenergy.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# @package _global_
# Hydra defaults list: selects one option per config group for this experiment.
defaults:
  - override /run: gschnet_run
  - override /globals: gschnet_globals
  - override /data: gschnet_qm9
  - override /task: gschnet_task
  # The callbacks group takes a list of options rather than a single one.
  - override /callbacks:
      - checkpoint
      - earlystopping
      - lrmonitor
      - progressbar
      - modelsummary
  - override /model: gschnet
  - override /model/conditioning: gap_relenergy
  # `_self_` comes last so that the values set in this file are composed after
  # (and therefore take precedence over) the group options selected above.
  - _self_
# Run location and identifier, both derived from other config nodes.
run:
  # `run.work_dir` is not set in this file; presumably it comes from the
  # selected `gschnet_run` option -- verify.
  path: ${run.work_dir}/models/qm9_${globals.name}
  id: ${globals.id}
# Shared values; other sections of this file read them through
# `${globals.<key>}` interpolations.
globals:
  # Cutoffs forwarded to the ConditionalGSchNetNeighborList transform.
  model_cutoff: 10.
  prediction_cutoff: 5.
  placement_cutoff: 1.7
  use_covalent_radii: True
  covalent_radius_factor: 1.1
  atom_types: [1, 6, 7, 8, 9]
  # Type ids forwarded to BuildAtomsTrajectory. They lie outside `atom_types`,
  # so they are presumably auxiliary (non-element) token types -- TODO confirm.
  origin_type: 121
  focus_type: 122
  stop_type: 123
  # NOTE(review): written without a decimal point; YAML 1.1 loaders such as
  # plain PyYAML read this form as a string. Presumably fine under
  # Hydra/OmegaConf -- confirm before loading this file with another parser.
  lr: 1e-4
  draw_random_samples: 0
  name: gap_relenergy
  # SLURM_JOBID if set, otherwise the value of the `uuid` resolver (registered
  # outside this file), followed by "_" and HOSTNAME (empty string if unset).
  id: ${oc.env:SLURM_JOBID,${uuid:1}}_${oc.env:HOSTNAME,""}
  data_workdir: null
  cache_workdir: null
# Per-callback setting overrides.
# NOTE(review): the keys here (`early_stopping`, `progress_bar`) are spelled
# differently from the option names in the defaults list (`earlystopping`,
# `progressbar`); presumably they are the node names defined inside those
# option files -- verify.
callbacks:
  early_stopping:
    patience: 25
  progress_bar:
    refresh_rate: 100
# Data settings: batch size, split sizes, and the list of transforms.
data:
  batch_size: 5
  num_train: 50000
  num_val: 5000
  # Each entry names a class via `_target_` followed by its keyword arguments.
  # Presumably the transforms are applied in the order listed -- TODO confirm.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack_gschnet.transform.OrderByDistanceToOrigin
    - _target_: schnetpack_gschnet.transform.GetRelativeAtomicEnergy
      atom_types: ${globals.atom_types}
      target_energy_name: energy_U0
      max_train_points: 5000
    # Cutoffs and covalent-radius settings are taken from `globals`.
    - _target_: schnetpack_gschnet.transform.ConditionalGSchNetNeighborList
      model_cutoff: ${globals.model_cutoff}
      prediction_cutoff: ${globals.prediction_cutoff}
      placement_cutoff: ${globals.placement_cutoff}
      environment_provider: matscipy
      use_covalent_radii: ${globals.use_covalent_radii}
      covalent_radius_factor: ${globals.covalent_radius_factor}
    # Type ids and sampling setting are taken from `globals`.
    - _target_: schnetpack_gschnet.transform.BuildAtomsTrajectory
      centered: True
      origin_type: ${globals.origin_type}
      focus_type: ${globals.focus_type}
      stop_type: ${globals.stop_type}
      draw_random_samples: ${globals.draw_random_samples}
      sort_idx_i: False
    - _target_: schnetpack.transform.CastTo32