Skip to content

Commit

Permalink
Updated loss and eval metrics (#896)
Browse files Browse the repository at this point in the history
* add density metrics

* update trainer & loss

* interleave atoms in loss

* fix call to keys

* add rmse to evaluation metrics

* fix linting.

* updated loss module with tests

* more tests and updated eval names

* failing test update

* renaming norm loss p2 -> l2

* minor changes based on comments

* added inline comment

---------

Co-authored-by: lbluque <[email protected]>
Co-authored-by: Xiang Fu <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent a2a53ee commit fbec2d3
Show file tree
Hide file tree
Showing 8 changed files with 523 additions and 151 deletions.
34 changes: 34 additions & 0 deletions src/fairchem/core/common/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Registry:
# Mappings to respective classes.
"task_name_mapping": {},
"dataset_name_mapping": {},
"loss_name_mapping": {},
"model_name_mapping": {},
"logger_name_mapping": {},
"trainer_name_mapping": {},
Expand Down Expand Up @@ -109,6 +110,35 @@ def wrap(func: Callable[..., R]) -> Callable[..., R]:

return wrap

@classmethod
def register_loss(cls, name):
r"""Register a loss to registry with key 'name'
Args:
name: Key with which the loss will be registered.
Usage::
from fairchem.core.common.registry import registry
from torch import nn
@registry.register_loss("mae")
class MAELoss(nn.Module):
...
"""

def wrap(func):
from torch import nn

assert issubclass(
func, nn.Module
), "All loss must inherit torch.nn.Module class"
cls.mapping["loss_name_mapping"][name] = func
return func

return wrap

@classmethod
def register_model(cls, name: str):
r"""Register a model to registry with key 'name'
Expand Down Expand Up @@ -255,6 +285,10 @@ def get_task_class(cls, name: str):
def get_dataset_class(cls, name: str):
return cls.get_class(name, "dataset_name_mapping")

@classmethod
def get_loss_class(cls, name):
return cls.get_class(name, "loss_name_mapping")

@classmethod
def get_model_class(cls, name: str):
return cls.get_class(name, "model_name_mapping")
Expand Down
16 changes: 0 additions & 16 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

import fairchem.core
from fairchem.core.common.registry import registry
from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -1433,21 +1432,6 @@ def update_config(base_config):
return config


def get_loss_module(loss_name):
if loss_name in ["l1", "mae"]:
loss_fn = nn.L1Loss()
elif loss_name == "mse":
loss_fn = nn.MSELoss()
elif loss_name == "l2mae":
loss_fn = L2MAELoss()
elif loss_name == "atomwisel2":
loss_fn = AtomwiseL2Loss()
else:
raise NotImplementedError(f"Unknown loss function name: {loss_name}")

return loss_fn


def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
if not os.path.isfile(checkpoint_path):
raise FileNotFoundError(
Expand Down
145 changes: 90 additions & 55 deletions src/fairchem/core/modules/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from functools import wraps
from typing import TYPE_CHECKING, Callable, ClassVar

import numpy as np
import torch
Expand All @@ -34,7 +35,7 @@
with the relevant metrics computed.
"""

NONE = slice(None)
NONE_SLICE = slice(None)


class Evaluator:
Expand Down Expand Up @@ -88,10 +89,9 @@ def eval(
self,
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
prev_metrics=None,
prev_metrics: dict | None = None,
):
if prev_metrics is None:
prev_metrics = {}
prev_metrics = prev_metrics or {}
metrics = prev_metrics

for target_property in self.target_metrics:
Expand Down Expand Up @@ -130,18 +130,98 @@ def update(self, key, stat, metrics):
return metrics


def metrics_dict(metric_fun: Callable) -> Callable:
"""Wrap up the return of a metrics function"""

@wraps(metric_fun)
def wrapped_metrics(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = None,
**kwargs,
) -> dict[str, torch.Tensor]:
error = metric_fun(prediction, target, key, **kwargs)
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}

return wrapped_metrics


@metrics_dict
def cosine_similarity(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
):
# cast to float 32 to avoid 0/nan issues in fp16
# https://github.com/pytorch/pytorch/issues/69512
return torch.cosine_similarity(prediction[key].float(), target[key].float())


@metrics_dict
def mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return torch.abs(target[key] - prediction[key])


@metrics_dict
def mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return (target[key] - prediction[key]) ** 2


@metrics_dict
def per_atom_mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return torch.abs(target[key] - prediction[key]) / target["natoms"].unsqueeze(1)


@metrics_dict
def per_atom_mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return ((target[key] - prediction[key]) / target["natoms"].unsqueeze(1)) ** 2


@metrics_dict
def magnitude_error(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
p: int = 2,
) -> torch.Tensor:
assert prediction[key].shape[1] > 1
return torch.abs(
torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1)
)


def forcesx_mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
key: Hashable = NONE_SLICE,
):
return mae(prediction["forces"][:, 0], target["forces"][:, 0])


def forcesx_mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
key: Hashable = NONE_SLICE,
):
return mse(prediction["forces"][:, 0], target["forces"][:, 0])

Expand Down Expand Up @@ -289,57 +369,12 @@ def min_diff(
return np.matmul(fractional, cell)


def cosine_similarity(
def rmse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
):
# cast to float 32 to avoid 0/nan issues in fp16
# https://github.com/pytorch/pytorch/issues/69512
error = torch.cosine_similarity(prediction[key].float(), target[key].float())
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
) -> dict[str, float | int]:
error = torch.abs(target[key] - prediction[key])
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
) -> dict[str, float | int]:
error = (target[key] - prediction[key]) ** 2
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def magnitude_error(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
p: int = 2,
key: Hashable = None,
) -> dict[str, float | int]:
assert prediction[key].shape[1] > 1
error = torch.abs(
torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1)
)
error = torch.sqrt(((target[key] - prediction[key]) ** 2).sum(dim=-1))
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
Expand Down
Loading

0 comments on commit fbec2d3

Please sign in to comment.