Skip to content

Commit

Permalink
fix trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Nov 9, 2024
1 parent 63d4eb0 commit 62659ff
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -155,9 +156,9 @@ def denoising_pos_eval(
evaluator: Evaluator,
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
denoising_targets: tuple[str],
prev_metrics: dict[str, torch.Tensor] | None = None,
denoising_pos_forward: bool = False,
denoising_targets=tuple[str],
):
"""
1. Overwrite the original Evaluator.eval() here: https://github.com/Open-Catalyst-Project/ocp/blob/5a7738f9aa80b1a9a7e0ca15e33938b4d2557edd/ocpmodels/modules/evaluator.py#L69-L81
Expand All @@ -168,12 +169,11 @@ def denoising_pos_eval(
return evaluator.eval(prediction, target, prev_metrics)

metrics = prev_metrics
for target_name in denoising_targets:
res = mae(prediction, target, target_name)
metrics = evaluator.update(f"denoising_{target_name}_mae", res, metrics)

for target in denoising_targets:
res = mae(prediction, target, target)
metrics = evaluator.update(f"denoising_{target}_mae", res, metrics)

if target.get("noise_mask", None) is None:
if target.get("noise_mask") is None:
# Only update`denoising_pos_mae` during denoising positions if not using partially corrupted structures
res = mae(prediction, target, "forces")
metrics = evaluator.update("denoising_pos_mae", res, metrics)
Expand Down Expand Up @@ -329,13 +329,12 @@ def __init__(
),
)
self.normalizers["denoising_pos_target"].to(self.device)
self.denoising_targets = None

def load_model(self) -> None:
super().load_model()
self.denoising_targets = tuple(
@cached_property
def denoising_targets(self):
return tuple(
head.output_name
for head in self.model.output_heads.values()
for head in self._unwrapped_model.output_heads.values()
if getattr(head, "use_denoising", False)
)

Expand Down Expand Up @@ -546,7 +545,6 @@ def _compute_loss(self, out, batch):
pred,
target,
natoms=natoms,
batch_size=batch_size,
)
)
else:
Expand All @@ -569,7 +567,6 @@ def _compute_loss(self, out, batch):
pred,
target,
natoms=natoms,
batch_size=batch_size,
)
)

Expand Down Expand Up @@ -611,7 +608,6 @@ def _compute_loss(self, out, batch):
pred,
target,
natoms=batch.natoms,
batch_size=batch_size,
)
)

Expand Down

0 comments on commit 62659ff

Please sign in to comment.