From 3c97c38c3f652a27659d3cbcbebcaf9b4c24041d Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 15 Aug 2024 22:54:08 +0000 Subject: [PATCH] make hydra compat with multitask --- src/fairchem/core/models/base.py | 4 ++-- src/fairchem/core/trainers/ocp_trainer.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 7380d036cd..08fdb3c583 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -264,6 +264,7 @@ def __init__( self.output_heads: dict[str, HeadInterface] = {} head_names_sorted = sorted(heads.keys()) + assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!" for head_name in head_names_sorted: head_config = heads[head_name] if "module" not in head_config: @@ -284,8 +285,7 @@ def forward(self, data: Batch): # Predict all output properties for all structures in the batch for now. out = {} for k in self.output_heads: - out.update(self.output_heads[k](data, emb)) - + out[k] = self.output_heads[k](data, emb) return out def get_backbone(self) -> BackboneInterface: diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 26269c6da4..cd39017fc2 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -251,9 +251,18 @@ def _forward(self, batch): for target_key in self.output_targets: ### Target property is a direct output of the model if target_key in out: - pred = out[target_key] - ## Target property is a derived output of the model. Construct the - ## parent property + if isinstance(out[target_key], torch.Tensor): + pred = out[target_key] + elif isinstance(out[target_key], dict): + # if output is a nested dictionary (in the case of hydra models), we attempt to retrieve it using the property name + # ie: "output_head_name.property" + assert "property" in self.output_targets[target_key], \ + f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}" + property = self.output_targets[target_key]["property"] + pred = out[target_key][property] + + ## TODO: deprecate the following logic? + ## Otherwise, assume target property is a derived output of the model. Construct the parent property else: _max_rank = 0 for subtarget_key in self.output_targets[target_key]["decomposition"]: