Merge branch 'main' into rgao_fuse_hydras_0
rayg1234 committed Aug 16, 2024
2 parents 9293cd0 + 2078e48 commit 701ac63
Showing 4 changed files with 135 additions and 18 deletions.
106 changes: 106 additions & 0 deletions configs/ocp_hydra_example.yml
@@ -0,0 +1,106 @@
trainer: equiformerv2_forces

dataset:
  train:
    format: lmdb
    src: data/s2ef/train/
    key_mapping:
      y: energy
      force: forces
    transforms:
      normalizer:
        energy:
          mean: -0.7554450631141663
          stdev: 2.887317180633545
        forces:
          mean: 0
          stdev: 2.887317180633545
  val:
    src: data/s2ef/val_id/

logger: wandb

outputs:
  energy:
    shape: 1
    level: system
  forces:
    irrep_dim: 1
    level: atom
    train_on_free_atoms: True
    eval_on_free_atoms: True

loss_functions:
  - energy:
      fn: mae
      coefficient: 4
  - forces:
      fn: l2mae
      coefficient: 100

evaluation_metrics:
  metrics:
    energy:
      - mae
    forces:
      - mae
      - cosine_similarity
      - magnitude_error
    misc:
      - energy_forces_within_threshold
  primary_metric: forces_mae

hide_eval_progressbar: False


model:
  # The model definition here uses "hydra".
  # The hydra model is really a container for a backbone model
  # and a variable number of head modules.
  name: hydra
  # Use a lightweight (4-layer) eSCN backbone.
  backbone:
    model: escn_backbone
    num_layers: 4
    max_neighbors: 20
    cutoff: 12.0
    sphere_channels: 128
    hidden_channels: 256
    lmax_list: [6]
    mmax_list: [2]
    num_sphere_samples: 128
    distance_function: "gaussian"
    regress_forces: True
    use_pbc: True
    basis_width_scalar: 2.0
    otf_graph: True
  # Attach an energy head and a direct-forces head to the eSCN backbone.
  # This outputs both the energy and direct forces for each input system.
  heads:
    energy:
      module: escn_energy_head
    forces:
      module: escn_force_head

optim:
  batch_size: 8  # 6
  eval_batch_size: 12  # 6
  load_balancing: atoms
  num_workers: 8
  lr_initial: 0.0004  # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96

  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.001
  scheduler: LambdaLR
  scheduler_params:
    lambda_type: cosine
    warmup_factor: 0.2
    warmup_epochs: 0.01
    lr_min_factor: 0.01

  max_epochs: 3
  clip_grad_norm: 100
  ema_decay: 0.999

  eval_every: 10000
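The comments in the model section describe the hydra pattern: one backbone pass, several named heads reading the shared embedding. A minimal sketch of that container idea, with a hypothetical TinyHydra class and toy stand-ins for the eSCN backbone and heads (not the actual fairchem hydra implementation):

import torch
from torch import nn


class TinyHydra(nn.Module):
    """Backbone + named heads container (illustrative only)."""

    def __init__(self, backbone: nn.Module, heads: dict[str, nn.Module]):
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleDict(heads)

    def forward(self, data) -> dict[str, torch.Tensor]:
        # One backbone pass produces shared embeddings...
        emb = self.backbone(data)
        # ...and each head maps them to its own target (energy, forces, ...).
        return {name: head(emb) for name, head in self.heads.items()}


# Usage with toy modules in place of escn_backbone / escn_*_head:
model = TinyHydra(
    backbone=nn.Linear(16, 32),
    heads={"energy": nn.Linear(32, 1), "forces": nn.Linear(32, 3)},
)
out = model(torch.randn(4, 16))  # {"energy": (4, 1), "forces": (4, 3)}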
4 changes: 3 additions & 1 deletion src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -637,6 +637,8 @@ def _init_weights(self, m):
             if self.weight_init == "normal":
                 std = 1 / math.sqrt(m.in_features)
                 torch.nn.init.normal_(m.weight, 0, std)
+            elif self.weight_init == "uniform":
+                self._uniform_init_linear_weights(m)
 
         elif isinstance(m, torch.nn.LayerNorm):
             torch.nn.init.constant_(m.bias, 0)
@@ -647,7 +649,7 @@ def _uniform_init_rad_func_linear_weights(self, m):
             m.apply(self._uniform_init_linear_weights)
 
     def _uniform_init_linear_weights(self, m):
-        if isinstance(m, torch.nn.Linear):
+        if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
             if m.bias is not None:
                 torch.nn.init.constant_(m.bias, 0)
             std = 1 / math.sqrt(m.in_features)
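The uniform branch above zeroes the bias and draws each weight from U(-std, std) with std = 1/sqrt(in_features); because the initializer is applied via Module.apply, extending the isinstance check to SO3_LinearV2 is what lets those layers pick it up too. A standalone sketch of the same rule, restricted to plain torch.nn.Linear (SO3_LinearV2 omitted here):

import math
import torch


def uniform_init_linear_weights(m: torch.nn.Module) -> None:
    # Zero the bias and draw weights from U(-std, std), std = 1 / sqrt(fan_in).
    if isinstance(m, torch.nn.Linear):
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        std = 1 / math.sqrt(m.in_features)
        torch.nn.init.uniform_(m.weight, -std, std)


# Module.apply walks every submodule, so every Linear in the model
# (and, in the patched code, every SO3_LinearV2) gets re-initialized.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 8))
model.apply(uniform_init_linear_weights)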
39 changes: 24 additions & 15 deletions src/fairchem/core/trainers/base_trainer.py
@@ -159,7 +159,11 @@ def __init__(
                 "%j", self.config["slurm"]["job_id"]
             )
             if distutils.is_master():
-                add_timestamp_id_to_submission_pickle(self.config["slurm"]["folder"], self.config["slurm"]["job_id"], self.timestamp_id)
+                add_timestamp_id_to_submission_pickle(
+                    self.config["slurm"]["folder"],
+                    self.config["slurm"]["job_id"],
+                    self.timestamp_id,
+                )
 
         # Define datasets
         if isinstance(dataset, list):
@@ -425,19 +429,23 @@ def load_references_and_normalizers(self):
                 elementref_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
             )
 
         if norms_config is not None:
             normalizers = load_normalizers_from_config(
                 norms_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
                 element_references=elementrefs,
             )
 
@@ -486,15 +494,15 @@ def load_task(self):
                     ][target_name].get("level", "system")
                     if "train_on_free_atoms" not in self.output_targets[subtarget]:
                         self.output_targets[subtarget]["train_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("train_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "train_on_free_atoms", True
+                            )
                         )
                     if "eval_on_free_atoms" not in self.output_targets[subtarget]:
                         self.output_targets[subtarget]["eval_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("eval_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "eval_on_free_atoms", True
+                            )
                         )
 
         # TODO: Assert that all targets, loss fn, metrics defined are consistent
@@ -550,13 +558,13 @@ def _unwrapped_model(self):
     def load_checkpoint(
         self, checkpoint_path: str, checkpoint: dict | None = None
     ) -> None:
+        map_location = torch.device("cpu") if self.cpu else self.device
         if checkpoint is None:
             if not os.path.isfile(checkpoint_path):
                 raise FileNotFoundError(
                     errno.ENOENT, "Checkpoint file not found", checkpoint_path
                 )
             logging.info(f"Loading checkpoint from: {checkpoint_path}")
-            map_location = torch.device("cpu") if self.cpu else self.device
             checkpoint = torch.load(checkpoint_path, map_location=map_location)
 
         self.epoch = checkpoint.get("epoch", 0)
@@ -599,13 +607,14 @@ def load_checkpoint(
                 mkeys = self.normalizers[target_key].load_state_dict(
                     checkpoint["normalizers"][key]
                 )
+                self.normalizers[target_key].to(map_location)
                 assert len(mkeys.missing_keys) == 0
                 assert len(mkeys.unexpected_keys) == 0
 
         for key, state_dict in checkpoint.get("elementrefs", {}).items():
             elementrefs = LinearReferences(
                 max_num_elements=len(state_dict["element_references"]) - 1
-            )
+            ).to(map_location)
             mkeys = elementrefs.load_state_dict(state_dict)
             self.elementrefs[key] = elementrefs
             assert len(mkeys.missing_keys) == 0
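The last two hunks follow one pattern: compute map_location once at the top of load_checkpoint, then move restored normalizers and element references onto that device so their buffers match the rest of the model. A simplified sketch of that pattern, with a hypothetical Normalizer placeholder and checkpoint key (not the fairchem trainer code):

import torch


class Normalizer(torch.nn.Module):
    """Placeholder for illustration: stores mean/std as buffers."""

    def __init__(self):
        super().__init__()
        self.register_buffer("mean", torch.zeros(1))
        self.register_buffer("std", torch.ones(1))


def load_normalizer(path: str, use_cpu: bool, device: torch.device) -> Normalizer:
    # Decide once where tensors should live, then reuse that everywhere.
    map_location = torch.device("cpu") if use_cpu else device
    checkpoint = torch.load(path, map_location=map_location)

    normalizer = Normalizer()
    # "normalizer" is an assumed checkpoint key for this sketch.
    normalizer.load_state_dict(checkpoint["normalizer"])
    # The buffers were created on CPU above; move them to the chosen device.
    return normalizer.to(map_location)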
4 changes: 2 additions & 2 deletions tests/core/models/__snapshots__/test_equiformer_v2.ambr
@@ -56,7 +56,7 @@
 # ---
 # name: TestEquiformerV2.test_gp.1
   Approx(
-    array([0.12408741], dtype=float32),
+    array([-0.03269595], dtype=float32),
     rtol=0.001,
     atol=0.001
   )
@@ -69,7 +69,7 @@
 # ---
 # name: TestEquiformerV2.test_gp.3
   Approx(
-    array([ 1.4928658e-03, -7.4134972e-05, 2.9909210e-03], dtype=float32),
+    array([ 0.00208857, -0.00017979, -0.0028318 ], dtype=float32),
     rtol=0.001,
     atol=0.001
   )
