fix tests
rayg1234 committed Aug 16, 2024
1 parent 62f08ab commit edc2d11
Showing 6 changed files with 13 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/fairchem/core/models/base.py
@@ -237,16 +237,18 @@ def __init__(
         heads: dict | None = None,
         finetune_config: dict | None = None,
         otf_graph: bool = True,
+        pass_through_head_outputs: bool = False,
     ):
         super().__init__()
         self.otf_graph = otf_graph
+        self.pass_through_head_outputs = pass_through_head_outputs

         # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
         starting_model = None
         if finetune_config is not None:
-            starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config['starting_checkpoint'])
+            starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"])
             assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!"

         if backbone is not None:
             backbone = copy.deepcopy(backbone)
             backbone_model_name = backbone.pop("model")
@@ -289,7 +291,10 @@ def forward(self, data: Batch):
         # Predict all output properties for all structures in the batch for now.
         out = {}
         for k in self.output_heads:
-            out[k] = self.output_heads[k](data, emb)
+            if self.pass_through_head_outputs:
+                out.update(self.output_heads[k](data, emb))
+            else:
+                out[k] = self.output_heads[k](data, emb)
         return out


1 change: 1 addition & 0 deletions src/fairchem/core/trainers/base_trainer.py
@@ -44,6 +44,7 @@
     update_config,
 )
 from fairchem.core.datasets.base_dataset import create_dataset
+
 # from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig
 from fairchem.core.modules.evaluator import Evaluator
 from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage
1 change: 1 addition & 0 deletions tests/core/models/test_configs/test_dpp_hydra.yml
@@ -54,6 +54,7 @@ model:
   heads:
     energy:
       module: dimenetplusplus_energy_and_force_head
+  pass_through_head_outputs: True

 # *** Important note ***
 # The total number of gpus used for this run was 256.
1 change: 1 addition & 0 deletions tests/core/models/test_configs/test_gemnet_dt_hydra.yml
@@ -75,6 +75,7 @@ model:
     forces:
       module: gemnet_t_force_head

+
 optim:
   batch_size: 8
   eval_batch_size: 8
@@ -73,6 +73,7 @@ model:
   heads:
     energy_and_forces:
       module: gemnet_t_energy_and_grad_force_head
+  pass_through_head_outputs: True

 optim:
   batch_size: 8
@@ -98,6 +98,7 @@ model:
     energy:
       module: gemnet_oc_energy_and_grad_force_head
       num_global_out_layers: 2
+  pass_through_head_outputs: True

 optim:
   batch_size: 5
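
All of the YAML edits above set the same constructor argument on HydraModel. Roughly, the model section of the last hunk maps onto constructor kwargs along these lines — a sketch; only the module name, num_global_out_layers, and the new flag come from the diff, the rest is illustrative:

# Sketch: how a config like the last hunk deserializes into HydraModel kwargs.
model_kwargs = {
    "heads": {
        "energy": {
            "module": "gemnet_oc_energy_and_grad_force_head",
            "num_global_out_layers": 2,
        },
    },
    # Required here because the head emits both energy and gradient-based
    # forces in one dict, which must be merged into the top-level output.
    "pass_through_head_outputs": True,
}
# model = HydraModel(backbone=..., **model_kwargs)  # hypothetical call site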
