Skip to content

Commit

Permalink
make hydra compat with multitask
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 15, 2024
1 parent ef2a4bc commit 3c97c38
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
self.output_heads: dict[str, HeadInterface] = {}

head_names_sorted = sorted(heads.keys())
assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!"
for head_name in head_names_sorted:
head_config = heads[head_name]
if "module" not in head_config:
Expand All @@ -284,8 +285,7 @@ def forward(self, data: Batch):
# Predict all output properties for all structures in the batch for now.
out = {}
for k in self.output_heads:
out.update(self.output_heads[k](data, emb))

out[k] = self.output_heads[k](data, emb)
return out

def get_backbone(self) -> BackboneInterface:
Expand Down
15 changes: 12 additions & 3 deletions src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,18 @@ def _forward(self, batch):
for target_key in self.output_targets:
### Target property is a direct output of the model
if target_key in out:
pred = out[target_key]
## Target property is a derived output of the model. Construct the
## parent property
if isinstance(out[target_key], torch.Tensor):
pred = out[target_key]
elif isinstance(out[target_key], dict):
# if output is a nested dictionary (in the case of hydra models), we attempt to retrieve it using the property name
# ie: "output_head_name.property"
assert "property" in self.output_targets[target_key], \
f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}"
property = self.output_targets[target_key]["property"]
pred = out[target_key][property]

## TODO: deprecate the following logic?
## Otherwise, assume target property is a derived output of the model. Construct the parent property
else:
_max_rank = 0
for subtarget_key in self.output_targets[target_key]["decomposition"]:
Expand Down

0 comments on commit 3c97c38

Please sign in to comment.