Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix radial/angular predictors + other changes #123

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
9cbecb6
things on delta
songk42 May 2, 2024
b95586e
fixed generation
songk42 May 3, 2024
6e8959b
try adding a discretized angular predictor
May 4, 2024
f9b032e
fix outdated import
songk42 May 4, 2024
0e34c16
attempt to make the discrete stuff actually discrete
May 6, 2024
8699687
syncing
songk42 May 23, 2024
eb03079
slurm
songk42 May 23, 2024
aaa3f81
toggle btwn discretized + continuous loss
May 12, 2024
c6e2dd2
most recent changes (scripts, etc)
May 2, 2024
0236ca5
add engaging to root dir
May 2, 2024
7a9abe4
fix some bugs
songk42 May 4, 2024
da40e04
max targets 6
songk42 May 7, 2024
9e77695
add radial/angular logits to eval prediction; fix conditional generat…
May 10, 2024
afdcf97
Create LICENSE
ameya98 May 26, 2024
6d59596
reset diverging branches
May 27, 2024
2593ab4
remove angular continuous param
May 27, 2024
eec3740
Merge branch 'main' of github.com:atomicarchitects/symphony into tmqm
May 28, 2024
a8ca4fb
save progress
Jun 25, 2024
a061857
Merge branch 'main' of github.com:atomicarchitects/symphony into tmqm
Jun 25, 2024
9bc1d84
minimal changes - for pulling
Jun 30, 2024
ef7149b
Merge branch 'main' of github.com:atomicarchitects/symphony into tmqm
Jun 30, 2024
a754eb4
make things run for single fragment
Jun 30, 2024
2a891e1
add jit back!
Jun 30, 2024
c0dd1c5
train fix
Jul 2, 2024
ac6487e
Merge branch 'main' of github.com:atomicarchitects/symphony into loss…
Jul 2, 2024
5493d0d
config
Jul 2, 2024
667c7b2
Merge branch 'main' of github.com:atomicarchitects/symphony into tmqm
Jul 2, 2024
7c4ae26
extra merge changes
Jul 2, 2024
96e5b4a
make this runnable
Jul 3, 2024
e28a2bc
Merge branch 'tmqm' of github.com:atomicarchitects/symphony into loss…
Jul 3, 2024
a880a07
run commit
Jul 3, 2024
73b0dcb
debugging
Jul 17, 2024
7968cd1
Merge branch 'main' of github.com:atomicarchitects/symphony into loss…
Jul 17, 2024
3d0728b
more small changes
Oct 28, 2024
450a178
fix angular dist code
Oct 29, 2024
f25d43b
add platonic solid support
Oct 29, 2024
bedaead
change back to the old generation code
Oct 29, 2024
d1bcdce
configs
Oct 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions analyses/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from configs import root_dirs

try:
from symphony.data import input_pipeline_tf
from symphony.data import input_pipeline
import tensorflow as tf

tf.config.experimental.set_visible_devices([], "GPU")
Expand Down Expand Up @@ -253,14 +253,7 @@ def load_from_workdir(
# Check that the config was loaded correctly.
assert config is not None
config = ml_collections.ConfigDict(config)
if 'max_targets_per_graph' in config:
config.root_dir = root_dirs.get_root_dir(
config.dataset, config.get("fragment_logic", "nn"), config.max_targets_per_graph
)
else:
config.root_dir = root_dirs.get_root_dir(
config.dataset, config.get("fragment_logic", "nn")
)
config.root_dir = root_dirs.get_root_dir(config.dataset)

# Mimic what we do in train.py.
rng = jax.random.PRNGKey(config.rng_seed)
Expand Down Expand Up @@ -295,31 +288,35 @@ def load_from_workdir(
else:
if init_graphs is None:
logging.info("Initializing dummy model with init_graphs from dataloader")
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)
train_iter = datasets["train"].as_numpy_iterator()
datasets = input_pipeline.get_datasets(dataset_rng, config)
train_iter = datasets["train"]
init_graphs = next(train_iter)
else:
logging.info("Initializing dummy model with provided init_graphs")

params = jax.jit(net.init)(init_rng, init_graphs)

tx = train.create_optimizer(config)
dummy_state = train_state.TrainState.create(
apply_fn=net.apply, params=params, tx=tx
)
# dummy_state = train_state.TrainState.create(
# apply_fn=net.apply, params=params, tx=tx
# )

# Load the actual values.
checkpoint_dir = os.path.join(workdir, "checkpoints")
ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=5)
data = ckpt.restore({"best_state": dummy_state, "metrics_for_best_state": None})
best_state = jax.tree_util.tree_map(jnp.asarray, data["best_state"])
data = ckpt.restore_dict()["state"]
best_state = jax.tree_util.tree_map(
jnp.asarray,
train_state.TrainState.create(
apply_fn=net.apply, params=data["best_params"], tx=tx
))
best_state_in_eval_mode = best_state.replace(apply_fn=eval_net.apply)

return (
config,
best_state,
best_state_in_eval_mode,
cast_keys_as_int(data["metrics_for_best_state"]),
cast_keys_as_int(data["metrics_for_best_params"]),
)


Expand Down
46 changes: 26 additions & 20 deletions analyses/conditional_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,13 @@

from absl import flags, app
import analyses.generate_molecules as generate_molecules
from symphony.data.datasets import qm9

from symphony.data.datasets import tmqm
from configs.root_dirs import get_root_dir
from symphony import datatypes


workdir = "/home/ameyad/spherical-harmonic-net/workdirs/qm9_bessel_embedding_attempt6_edm_splits/e3schnet_and_nequip/interactions=3/l=5/position_channels=2/channels=64"
outputdir = "conditional_generation"
beta_species = 1.0
beta_position = 1.0
step = "7530000"
num_seeds_per_chunk = 25
max_num_atoms = 35
visualize = False
num_mols = 1000

all_mols = qm9.load_qm9("../qm9_data", use_edm_splits=True, check_molecule_sanity=False)
test_mols = all_mols[-num_mols:]
train_mols = all_mols[:num_mols]


def get_fragment_list(mols: Sequence[ase.Atoms], num_mols: int):
Expand All @@ -46,36 +36,42 @@ def get_fragment_list(mols: Sequence[ase.Atoms], num_mols: int):


def main(unused_argv: Sequence[str]):
radial_cutoff = 5.0
beta_species = 1.0
beta_position = 1.0
step = flags.FLAGS.step
num_seeds_per_chunk = 1
max_num_atoms = 200
max_num_steps = 10
num_mols = 20
max_num_atoms = 50
num_mols = 500
avg_neighbors_per_atom = 32

atomic_numbers = np.arange(1, 81)

all_mols = tmqm.load_tmqm("../tmqm_data")
mols_by_split = {"train": all_mols[:num_mols], "test": all_mols[-num_mols:]}

for split, split_mols in mols_by_split.items():
# Ensure that the number of molecules is a multiple of num_seeds_per_chunk.
mol_list = get_fragment_list(split_mols, num_mols)
mol_list = split_mols[
: num_seeds_per_chunk * (len(split_mols) // num_seeds_per_chunk)
mol_list = mol_list[
: num_seeds_per_chunk * (len(mol_list) // num_seeds_per_chunk)
]
print(f"Number of fragments for {split}: {len(mol_list)}")

gen_mol_list = generate_molecules.generate_molecules(
gen_mol_list = generate_molecules.generate_molecules_from_workdir(
flags.FLAGS.workdir,
os.path.join(flags.FLAGS.outputdir, split),
radial_cutoff,
beta_species,
beta_position,
step,
flags.FLAGS.steps_for_weight_averaging,
len(mol_list),
num_seeds_per_chunk,
mol_list,
max_num_atoms,
max_num_steps,
avg_neighbors_per_atom,
atomic_numbers,
flags.FLAGS.visualize,
)

Expand All @@ -101,4 +97,14 @@ def main(unused_argv: Sequence[str]):
"best",
"Step number to load model from. The default corresponds to the best model.",
)
flags.DEFINE_list(
"steps_for_weight_averaging",
None,
"Steps to average parameters over. If None, the model at the given step is used.",
)
flags.DEFINE_bool(
"seed_structure",
False,
"Add initial atom of the missing element to structure"
)
app.run(main)
3 changes: 1 addition & 2 deletions analyses/elements.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
1 H 0.46 1.20 0.200 1.00000 1.00000 1.00000
1 D 0.46 1.20 0.200 0.80000 0.80000 1.00000
2 He 1.22 1.40 1.220 0.85100 1.00000 1.00000
3 Li 1.57 1.40 0.590 0.80000 0.50200 1.00000
4 Be 1.12 1.40 0.270 0.76100 1.00000 0.00000
Expand Down Expand Up @@ -94,4 +93,4 @@
93 Np 1.56 2.16 0.750 0.00000 0.50200 1.00000
94 Pu 1.64 2.16 0.860 0.00000 0.42000 1.00000
95 Am 1.73 2.16 0.975 0.32900 0.36100 0.94900
96 XX 0.80 1.00 0.800 0.47100 0.36100 0.89000
96 XX 0.80 1.00 0.800 0.47100 0.36100 0.89000
Loading