
Merge pull request #302 from choderalab/dev-add-regression-plots
add regression plots and error histograms for training/validation/testing
wiederm authored Nov 4, 2024
2 parents f20c713 + 8ca7fe3 commit 2a6b208
Showing 10 changed files with 915 additions and 84 deletions.
8 changes: 8 additions & 0 deletions modelforge/potential/aimnet2.py
@@ -156,6 +156,12 @@ def compute_properties(
            }
        )["per_atom_charge"]

+        # check that none of the tensors are NaN
+        if torch.isnan(atomic_embedding).any():
+            raise ValueError("NaN values detected in atomic embeddings.")
+        if torch.isnan(partial_charges).any():
+            raise ValueError("NaN values detected in partial charges.")
+
        return {
            "per_atom_scalar_representation": atomic_embedding,
            "atomic_subsystem_indices": data.atomic_subsystem_indices,
@@ -308,6 +314,8 @@ def calculate_contributions(

        # Accumulate the vector contributions using index_add_
        vector_contributions.index_add_(0, idx_j, vector_prot_step2)
+        if torch.isnan(vector_contributions).any():
+            raise ValueError("NaN values detected in vector_contributions.")

        # Step 3: Compute the Euclidean Norm for each atom
        vector_norms = torch.norm(
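Both hunks guard against silent NaN propagation with the same fail-fast pattern. A minimal sketch of that pattern factored into a reusable helper; `assert_no_nan` is illustrative and not part of this commit:

```python
import torch

def assert_no_nan(tensor: torch.Tensor, name: str) -> None:
    # Fail fast and name the offending tensor, mirroring the inline
    # checks added above to compute_properties and calculate_contributions.
    if torch.isnan(tensor).any():
        raise ValueError(f"NaN values detected in {name}.")

# e.g. assert_no_nan(partial_charges, "partial charges") raises the same
# ValueError as the inline check.
```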
8 changes: 8 additions & 0 deletions modelforge/potential/ani.py
@@ -17,6 +17,12 @@
from modelforge.potential.neighbors import PairlistData


+def init_params(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
+        torch.nn.init.zeros_(m.bias)
+
+
def triu_index(number_of_atoms: int) -> torch.Tensor:
    """
    Generate a tensor representing the upper triangular indices for species
@@ -778,6 +784,8 @@ def __init__(
            lookup_tensor[atomic_number] = index

        self.register_buffer("lookup_tensor", lookup_tensor)
+        # Apply the custom weight initialization
+        self.apply(init_params)

    def compute_properties(
        self,
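`self.apply(init_params)` uses `torch.nn.Module.apply`, which visits the module and every submodule recursively, so each `torch.nn.Linear` in the network receives Kaiming-normal weights and zero biases. A self-contained illustration; the toy network below is an assumption, not modelforge code:

```python
import torch

def init_params(m):
    # Same hook as the diff above: Kaiming-normal weights and zero
    # biases for every Linear layer.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)

net = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.CELU(),
    torch.nn.Linear(16, 1),
)
net.apply(init_params)  # visits both Linear layers, skips the CELU
assert torch.all(net[0].bias == 0.0)
```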
6 changes: 3 additions & 3 deletions modelforge/tests/data/runtime_defaults/runtime.toml
@@ -3,8 +3,8 @@ experiment_name = "{potential_name}_{dataset_name}"
local_cache_dir = "./cache"
accelerator = "cpu"
number_of_nodes = 1
-devices = 1 #[0,1,2,3]
+devices = 1 #[0,1,2,3]
checkpoint_path = "None"
simulation_environment = "PyTorch"
-log_every_n_steps = 1
-verbose = true
+log_every_n_steps = 50
+verbose = true
1 change: 1 addition & 0 deletions modelforge/tests/data/training_defaults/default.toml
@@ -5,6 +5,7 @@ shift_center_of_mass_to_origin = false
batch_size = 128
lr = 5e-4
monitor = "val/per_system_energy/rmse" # Common monitor key
+plot_frequency = 1
# ------------------------------------------------------------ #
[training.experiment_logger]
logger_name = "tensorboard" # this will set which logger to use
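With `plot_frequency = 1`, regression plots and error histograms are produced every epoch. A hedged sketch of how such a setting could gate plotting; the commit's actual plotting code lives in files collapsed from this view, so every name below is illustrative:

```python
import matplotlib.pyplot as plt
import numpy as np

def maybe_plot(epoch: int, plot_frequency: int, pred: np.ndarray, ref: np.ndarray, out_dir: str) -> None:
    # Only plot on every plot_frequency-th epoch.
    if epoch % plot_frequency != 0:
        return
    fig, (ax_reg, ax_err) = plt.subplots(1, 2, figsize=(10, 4))
    ax_reg.scatter(ref, pred, s=5, alpha=0.5)  # regression plot
    ax_reg.set_xlabel("reference energy")
    ax_reg.set_ylabel("predicted energy")
    ax_err.hist(pred - ref, bins=50)  # error histogram
    ax_err.set_xlabel("prediction error")
    fig.savefig(f"{out_dir}/regression_epoch_{epoch}.png")
    plt.close(fig)
```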
99 changes: 86 additions & 13 deletions modelforge/tests/test_ani.py
@@ -72,6 +72,16 @@ def setup_two_methanes():
        requires_grad=True,
        device=device,
    )
+    # Specify the translation vector
+    translation_vector = torch.tensor([1.0, 1.0, 1.0], device=device)
+    # Translate the second "molecule" without in-place modification
+    translated_coordinates = (
+        coordinates.clone()
+    )  # Clone the tensor to avoid in-place modification
+    translated_coordinates[1] = translated_coordinates[1] + translation_vector
+
+    print(translated_coordinates)
+
    # In periodic table, C = 6 and H = 1
    mf_species = torch.tensor([6, 1, 1, 1, 1, 6, 1, 1, 1, 1], device=device)
    ani_species = torch.tensor([[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]], device=device)
@@ -84,11 +94,14 @@

    nnp_input = NNPInput(
        atomic_numbers=atomic_numbers,
-        positions=torch.cat((coordinates[0], coordinates[1]), dim=0) / 10,
+        positions=torch.cat(
+            (translated_coordinates[0], translated_coordinates[1]), dim=0
+        )
+        / 10,
        atomic_subsystem_indices=atomic_subsystem_indices,
        per_system_total_charge=torch.tensor([0.0, 0.0]),
    )
-    return ani_species, coordinates, device, nnp_input
+    return ani_species, translated_coordinates, device, nnp_input


@pytest.mark.xfail
@@ -110,11 +123,11 @@ def test_ani():
    # calculate energy for methane
    energy = model((species, coordinates)).energies
    # get per atom energy
-    w, ref_atomic_energies = model.atomic_energies((species, coordinates))
+    w, torchani_atomic_energies = model.atomic_energies((species, coordinates))

    # compare to reference energy
    assert torch.allclose(
-        ref_atomic_energies,
+        torchani_atomic_energies,
        torch.tensor(
            [
                [-38.0841, -0.5797, -0.5898, -0.6034, -0.6027],
@@ -129,15 +142,23 @@
    # NOTE: this is in Hartree
    reference_ase = torch.tensor(
        [
-            [0.0052, 0.0181, 0.0080, -0.0055, -0.0048],
-            [0.0052, 0.0181, 0.0080, -0.0055, -0.0048],
+            [
+                -38.08933878049795,
+                0.5978583943827134,
+                0.5978583943827134,
+                0.5978583943827134,
+                0.5978583943827134,
+            ],
+            [
+                -38.08933878049795,
+                0.5978583943827134,
+                0.5978583943827134,
+                0.5978583943827134,
+                0.5978583943827134,
+            ],
        ],
    ) - torch.tensor(
        [
            [-38.0841, -0.5797, -0.5898, -0.6034, -0.6027],
            [-38.0841, -0.5797, -0.5898, -0.6034, -0.6027],
        ]
    )

    # ------------------------------------------ #
    # setup modelforge potential
    potential = setup_potential_for_test(
@@ -151,13 +172,65 @@
    potential.load_state_dict(torch.load(file_path))
    # compare to original ani2x dataset
    atomic_energies = potential(mf_input)["per_atom_energy"]
+    modelforge_atomic_energies = (
+        atomic_energies.flatten() + reference_ase.squeeze(0).flatten()
+    )
+
+    print(atomic_energies.flatten())
+    print(torchani_atomic_energies.flatten() - reference_ase.flatten())
+
+    print(modelforge_atomic_energies)
+    print(torchani_atomic_energies.flatten())
+
    assert torch.allclose(
-        atomic_energies.flatten() - reference_ase.flatten(),
-        ref_atomic_energies.flatten(),
+        modelforge_atomic_energies,
+        torchani_atomic_energies.flatten(),
        rtol=1e-3,
    )
+
+
+def test_ani_against_torchani_reference():
+    import torch
+
+    # get input
+    species, coordinates, device, mf_input = setup_two_methanes()
+
+    # ------------------------------------------ #
+    # setup modelforge potential
+    potential = setup_potential_for_test(
+        use="training",
+        potential_seed=42,
+        potential_name="ani2x",
+        jit=False,
+        local_cache_dir=str(prep_temp_dir),
+    )
+    # load the original ani2x parameter set
+    potential.load_state_dict(torch.load(file_path))
+    # compare to original ani2x dataset
+    atomic_energies = potential(mf_input)["per_atom_energy"]
+
+    assert torch.allclose(
+        atomic_energies,
+        torch.tensor(
+            [
+                [0.0052],
+                [0.0181],
+                [0.0080],
+                [-0.0055],
+                [-0.0048],
+                [0.0052],
+                [0.0181],
+                [0.0080],
+                [-0.0055],
+                [-0.0048],
+            ]
+        ),
+        rtol=1e-2,
+    )  # that's the atomic energies for the two methane molecules obtained with torchani
+
+    a = 7


@pytest.mark.parametrize("mode", ["inference", "training"])
def test_forward_and_backward(mode):
    # Test modelforge ANI implementation
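Translating the second methane by a constant vector makes the fixture sensitive to translational invariance: a rigid shift of all atoms must leave predicted energies unchanged. A generic sketch of that property as a standalone check, assuming only an energy function of positions (not the modelforge classes above):

```python
import torch

def assert_translation_invariant(energy_fn, positions: torch.Tensor, atol: float = 1e-5) -> None:
    # A rigid translation of an isolated system should not change its energy.
    shift = torch.tensor([1.0, 1.0, 1.0])
    e0 = energy_fn(positions)
    e1 = energy_fn(positions + shift)
    assert torch.allclose(e0, e1, atol=atol)

# Toy example: an energy built from interatomic distances is invariant.
energy = lambda pos: torch.cdist(pos, pos).sum()
assert_translation_invariant(energy, torch.randn(5, 3))
```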
2 changes: 1 addition & 1 deletion modelforge/tests/test_training.py
@@ -344,7 +344,7 @@ def test_train_with_lightning(loss, potential_name, dataset_name, prep_temp_dir)
    # train potential
    get_trainer(config).train_potential().save_checkpoint("test.chp")  # save checkpoint
    # continue training from checkpoint
-    get_trainer(config).train_potential()
+    # get_trainer(config).train_potential()


def test_train_from_single_toml_file(prep_temp_dir):
4 changes: 4 additions & 0 deletions modelforge/train/losses.py
@@ -399,6 +399,10 @@ def forward(
                predict_target[f"{prop_}_true"],
                batch,
            )
+            # check that none of the tensors are NaN
+            if torch.isnan(prop_loss).any():
+                raise ValueError(f"NaN values detected in {prop} loss.")
+
            # Accumulate weighted per-sample losses
            weighted_loss = self.weights_scheduling[prop][epoch_idx] * prop_loss
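For context, this check sits inside a loop that scales each property loss by an epoch-indexed weight before accumulation. A simplified sketch of that accumulation; the names follow the hunk above, but the free function itself is illustrative:

```python
import torch

def accumulate_weighted_losses(
    prop_losses: dict, weights_scheduling: dict, epoch_idx: int
) -> torch.Tensor:
    total = torch.zeros(())
    for prop, prop_loss in prop_losses.items():
        # check that none of the tensors are NaN (as in the hunk above)
        if torch.isnan(prop_loss).any():
            raise ValueError(f"NaN values detected in {prop} loss.")
        # accumulate weighted per-sample losses
        total = total + weights_scheduling[prop][epoch_idx] * prop_loss.mean()
    return total
```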
3 changes: 2 additions & 1 deletion modelforge/train/parameters.py
@@ -378,7 +378,8 @@ def ensure_logger_configuration(self) -> "ExperimentLogger":
    shift_center_of_mass_to_origin: bool
    batch_size: int
    lr: float
-    lr_scheduler: Optional[SchedulerConfig] = None  # Use the Union type here
+    plot_frequency: int = 5  # how often to log regression and error histograms
+    lr_scheduler: Optional[SchedulerConfig] = None
    loss_parameter: LossParameter
    early_stopping: Optional[EarlyStopping] = None
    splitting_strategy: SplittingStrategy
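The surrounding class appears to be a Pydantic model (note the `ensure_logger_configuration` validator in the hunk header), so the new `plot_frequency` field is type-checked when the TOML is loaded. A hypothetical sketch of additionally constraining it to positive values; this validator is not part of the commit:

```python
from pydantic import BaseModel, field_validator

class TrainingSketch(BaseModel):
    # epochs between regression plots / error histograms
    plot_frequency: int = 5

    @field_validator("plot_frequency")
    @classmethod
    def must_be_positive(cls, v: int) -> int:
        if v < 1:
            raise ValueError("plot_frequency must be >= 1")
        return v

TrainingSketch(plot_frequency=1)    # ok
# TrainingSketch(plot_frequency=0)  # would raise a ValidationError
```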
