diff --git a/src/fairchem/core/common/registry.py b/src/fairchem/core/common/registry.py index 2912c7c50..59a83dd88 100644 --- a/src/fairchem/core/common/registry.py +++ b/src/fairchem/core/common/registry.py @@ -61,6 +61,7 @@ class Registry: # Mappings to respective classes. "task_name_mapping": {}, "dataset_name_mapping": {}, + "loss_name_mapping": {}, "model_name_mapping": {}, "logger_name_mapping": {}, "trainer_name_mapping": {}, @@ -109,6 +110,35 @@ def wrap(func: Callable[..., R]) -> Callable[..., R]: return wrap + @classmethod + def register_loss(cls, name): + r"""Register a loss to registry with key 'name' + + Args: + name: Key with which the loss will be registered. + + Usage:: + + from fairchem.core.common.registry import registry + from torch import nn + + @registry.register_loss("mae") + class MAELoss(nn.Module): + ... + + """ + + def wrap(func): + from torch import nn + + assert issubclass( + func, nn.Module + ), "All loss must inherit torch.nn.Module class" + cls.mapping["loss_name_mapping"][name] = func + return func + + return wrap + @classmethod def register_model(cls, name: str): r"""Register a model to registry with key 'name' @@ -255,6 +285,10 @@ def get_task_class(cls, name: str): def get_dataset_class(cls, name: str): return cls.get_class(name, "dataset_name_mapping") + @classmethod + def get_loss_class(cls, name): + return cls.get_class(name, "loss_name_mapping") + @classmethod def get_model_class(cls, name: str): return cls.get_class(name, "model_name_mapping") diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index a31983e93..24a3d7a44 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -43,7 +43,6 @@ import fairchem.core from fairchem.core.common.registry import registry -from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss if TYPE_CHECKING: from collections.abc import Mapping @@ -1433,21 +1432,6 @@ def update_config(base_config): return config -def get_loss_module(loss_name): - if loss_name in ["l1", "mae"]: - loss_fn = nn.L1Loss() - elif loss_name == "mse": - loss_fn = nn.MSELoss() - elif loss_name == "l2mae": - loss_fn = L2MAELoss() - elif loss_name == "atomwisel2": - loss_fn = AtomwiseL2Loss() - else: - raise NotImplementedError(f"Unknown loss function name: {loss_name}") - - return loss_fn - - def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module: if not os.path.isfile(checkpoint_path): raise FileNotFoundError( diff --git a/src/fairchem/core/modules/evaluator.py b/src/fairchem/core/modules/evaluator.py index c84021eef..b21ebad5d 100644 --- a/src/fairchem/core/modules/evaluator.py +++ b/src/fairchem/core/modules/evaluator.py @@ -7,7 +7,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar +from functools import wraps +from typing import TYPE_CHECKING, Callable, ClassVar import numpy as np import torch @@ -34,7 +35,7 @@ with the relevant metrics computed. """ -NONE = slice(None) +NONE_SLICE = slice(None) class Evaluator: @@ -88,10 +89,9 @@ def eval( self, prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], - prev_metrics=None, + prev_metrics: dict | None = None, ): - if prev_metrics is None: - prev_metrics = {} + prev_metrics = prev_metrics or {} metrics = prev_metrics for target_property in self.target_metrics: @@ -130,10 +130,90 @@ def update(self, key, stat, metrics): return metrics +def metrics_dict(metric_fun: Callable) -> Callable: + """Wrap up the return of a metrics function""" + + @wraps(metric_fun) + def wrapped_metrics( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = None, + **kwargs, + ) -> dict[str, torch.Tensor]: + error = metric_fun(prediction, target, key, **kwargs) + return { + "metric": torch.mean(error).item(), + "total": torch.sum(error).item(), + "numel": error.numel(), + } + + return wrapped_metrics + + +@metrics_dict +def cosine_similarity( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, +): + # cast to float 32 to avoid 0/nan issues in fp16 + # https://github.com/pytorch/pytorch/issues/69512 + return torch.cosine_similarity(prediction[key].float(), target[key].float()) + + +@metrics_dict +def mae( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, +) -> torch.Tensor: + return torch.abs(target[key] - prediction[key]) + + +@metrics_dict +def mse( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, +) -> torch.Tensor: + return (target[key] - prediction[key]) ** 2 + + +@metrics_dict +def per_atom_mae( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, +) -> torch.Tensor: + return torch.abs(target[key] - prediction[key]) / target["natoms"].unsqueeze(1) + + +@metrics_dict +def per_atom_mse( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, +) -> torch.Tensor: + return ((target[key] - prediction[key]) / target["natoms"].unsqueeze(1)) ** 2 + + +@metrics_dict +def magnitude_error( + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + key: Hashable = NONE_SLICE, + p: int = 2, +) -> torch.Tensor: + assert prediction[key].shape[1] > 1 + return torch.abs( + torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1) + ) + + def forcesx_mae( prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], - key: Hashable = NONE, + key: Hashable = NONE_SLICE, ): return mae(prediction["forces"][:, 0], target["forces"][:, 0]) @@ -141,7 +221,7 @@ def forcesx_mae( def forcesx_mse( prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], - key: Hashable = NONE, + key: Hashable = NONE_SLICE, ): return mse(prediction["forces"][:, 0], target["forces"][:, 0]) @@ -289,57 +369,12 @@ def min_diff( return np.matmul(fractional, cell) -def cosine_similarity( +def rmse( prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], - key: Hashable = NONE, -): - # cast to float 32 to avoid 0/nan issues in fp16 - # https://github.com/pytorch/pytorch/issues/69512 - error = torch.cosine_similarity(prediction[key].float(), target[key].float()) - return { - "metric": torch.mean(error).item(), - "total": torch.sum(error).item(), - "numel": error.numel(), - } - - -def mae( - prediction: dict[str, torch.Tensor], - target: dict[str, torch.Tensor], - key: Hashable = NONE, -) -> dict[str, float | int]: - error = torch.abs(target[key] - prediction[key]) - return { - "metric": torch.mean(error).item(), - "total": torch.sum(error).item(), - "numel": error.numel(), - } - - -def mse( - prediction: dict[str, torch.Tensor], - target: dict[str, torch.Tensor], - key: Hashable = NONE, -) -> dict[str, float | int]: - error = (target[key] - prediction[key]) ** 2 - return { - "metric": torch.mean(error).item(), - "total": torch.sum(error).item(), - "numel": error.numel(), - } - - -def magnitude_error( - prediction: dict[str, torch.Tensor], - target: dict[str, torch.Tensor], - key: Hashable = NONE, - p: int = 2, + key: Hashable = None, ) -> dict[str, float | int]: - assert prediction[key].shape[1] > 1 - error = torch.abs( - torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1) - ) + error = torch.sqrt(((target[key] - prediction[key]) ** 2).sum(dim=-1)) return { "metric": torch.mean(error).item(), "total": torch.sum(error).item(), diff --git a/src/fairchem/core/modules/loss.py b/src/fairchem/core/modules/loss.py index b737e79e6..02754ead6 100644 --- a/src/fairchem/core/modules/loss.py +++ b/src/fairchem/core/modules/loss.py @@ -1,77 +1,172 @@ from __future__ import annotations import logging +from typing import Literal import torch from torch import nn from fairchem.core.common import distutils +from fairchem.core.common.registry import registry -class L2MAELoss(nn.Module): - def __init__(self, reduction: str = "mean") -> None: +@registry.register_loss("mae") +class MAELoss(nn.Module): + def __init__(self) -> None: super().__init__() - self.reduction = reduction - assert reduction in ["mean", "sum"] + self.loss = nn.L1Loss() + # reduction should be none as it is handled in DDPLoss + self.loss.reduction = "none" - def forward(self, input: torch.Tensor, target: torch.Tensor): - dists = torch.norm(input - target, p=2, dim=-1) - if self.reduction == "mean": - return torch.mean(dists) - elif self.reduction == "sum": - return torch.sum(dists) - - return dists + def forward( + self, pred: torch.Tensor, target: torch.Tensor, natoms: torch.Tensor + ) -> torch.Tensor: + return self.loss(pred, target) -class AtomwiseL2Loss(nn.Module): - def __init__(self, reduction: str = "mean") -> None: +@registry.register_loss("mse") +class MSELoss(nn.Module): + def __init__(self) -> None: super().__init__() - self.reduction = reduction - assert reduction in ["mean", "sum"] + self.loss = nn.MSELoss() + # reduction should be none as it is handled in DDPLoss + self.loss.reduction = "none" def forward( - self, - input: torch.Tensor, - target: torch.Tensor, - natoms: torch.Tensor, - ): - assert natoms.shape[0] == input.shape[0] == target.shape[0] - assert len(natoms.shape) == 1 # (nAtoms, ) + self, pred: torch.Tensor, target: torch.Tensor, natoms: torch.Tensor + ) -> torch.Tensor: + return self.loss(pred, target) - dists = torch.norm(input - target, p=2, dim=-1) - loss = natoms * dists - if self.reduction == "mean": - return torch.mean(loss) - elif self.reduction == "sum": - return torch.sum(loss) - return None +@registry.register_loss("per_atom_mae") +class PerAtomMAELoss(nn.Module): + """ + Simply divide a loss by the number of atoms/nodes in the graph. + Current this loss is intened to used with scalar values, not vectors or higher tensors. + """ + + def __init__(self) -> None: + super().__init__() + self.loss = nn.L1Loss() + # reduction should be none as it is handled in DDPLoss + self.loss.reduction = "none" + + def forward( + self, pred: torch.Tensor, target: torch.Tensor, natoms: torch.Tensor + ) -> torch.Tensor: + _natoms = torch.reshape(natoms, target.shape) + # check if target is a scalar + assert target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1) + # check per_atom shape + assert (target / _natoms).shape == target.shape + return self.loss(pred / _natoms, target / _natoms) + + +@registry.register_loss("l2norm") +@registry.register_loss("l2mae") +class L2NormLoss(nn.Module): + """ + Currently this loss is intened to used with vectors. + """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, pred: torch.Tensor, target: torch.Tensor, natoms: torch.Tensor + ) -> torch.Tensor: + assert target.dim() == 2 + assert target.shape[1] != 1 + return torch.linalg.vector_norm(pred - target, ord=2, dim=-1) class DDPLoss(nn.Module): + """ + This class is a wrapper around a loss function that does a few things + like handle nans and importantly ensures the reduction is done + correctly for DDP. The main issue is that DDP averages gradients + over replicas — this only works out of the box if the dimension + you are averaging over is completely consistent across all replicas. + In our case, that is not true for the number of atoms per batch and + there are edge cases when the batch size differs between replicas + e.g. if the dataset size is not divisible by the batch_size. + + Scalars are relatively straightforward to handle, but vectors and higher tensors + are a bit trickier. Below are two examples of forces. + + Forces input: [Nx3] target: [Nx3] + Forces are a vector of length 3 (x,y,z) for each atom. + Number of atoms per batch (N) is different for each DDP replica. + + MSE example: + #### Local loss computation #### + local_loss = MSELoss(input, target) -> [Nx3] + num_samples = local_loss.numel() -> [Nx3] + local_loss = sum(local_loss [Nx3]) -> [1] sum reduces the loss to a scalar + global_samples = all_reduce(num_samples) -> [N0x3 + N1x3 + N2x3 + ...] = [1] where N0 is the number of atoms on replica 0 + local_loss = local_loss * world_size / global_samples -> [1] + #### Global loss computation #### + global_loss = sum(local_loss / world_size) -> [1] + == sum(local_loss / global_samples) # this is the desired corrected mean + + Norm example: + #### Local loss computation #### + local_loss = L2MAELoss(input, target) -> [N] + num_samples = local_loss.numel() -> [N] + local_loss = sum(local_loss [N]) -> [1] sum reduces the loss to a scalar + global_samples = all_reduce(num_samples) -> [N0 + N1 + N2 + ...] = [1] where N0 is the number of atoms on replica 0 + local_loss = local_loss * world_size / global_samples -> [1] + #### Global loss computation #### + global_loss = sum(local_loss / world_size) -> [1] + == sum(local_loss / global_samples) # this is the desired corrected mean + """ + def __init__( - self, loss_fn, loss_name: str = "mae", reduction: str = "mean" + self, + loss_name, + reduction: Literal["mean", "sum"], ) -> None: super().__init__() - self.loss_fn = loss_fn - self.loss_name = loss_name - self.reduction = reduction - assert reduction in ["mean", "mean_all", "sum"] - - # for forces, we want to sum over xyz errors and average over batches/atoms (mean) - # for other metrics, we want to average over all axes (mean_all) or leave as a sum (sum) - if reduction == "mean_all": - self.loss_fn.reduction = "mean" + self.loss_fn = registry.get_loss_class(loss_name)() + # default reduction is mean + self.reduction = reduction if reduction is not None else "mean" + self.reduction_map = { + "mean": self.mean, + "sum": self.sum, + } + assert self.reduction in list( + self.reduction_map.keys() + ), "Reduction must be one of: 'mean', 'sum'" + + def sum(self, input, loss, natoms): + # this sum will reduce the loss down to a single scalar + return torch.sum(loss) + + def _ddp_mean(self, num_samples, loss): + global_samples = distutils.all_reduce(num_samples, device=loss.device) + # Multiply by world size since gradients are averaged across DDP replicas + # warning this is probably incorrect for any model parallel approach + return loss * distutils.get_world_size() / global_samples + + def mean(self, input, loss, natoms): + # total elements to take the mean over + # could be batch_size, num_atoms, num_atomsx3, etc + num_samples = loss.numel() + # this sum will reduce the loss down from num_sample -> 1 + loss = self.sum(input, loss, natoms) + return self._ddp_mean(num_samples, loss) + + def _reduction(self, input, loss, natoms): + if self.reduction in self.reduction_map: + return self.reduction_map[self.reduction](input, loss, natoms) else: - self.loss_fn.reduction = "sum" + raise ValueError("Reduction must be one of: 'mean', 'sum'") def forward( self, input: torch.Tensor, target: torch.Tensor, - natoms: torch.Tensor | None = None, - batch_size: int | None = None, + natoms: torch.Tensor, ): # ensure torch doesn't do any unwanted broadcasting assert ( @@ -84,19 +179,5 @@ def forward( logging.warning("Found nans while computing loss") input = torch.nan_to_num(input, nan=0.0) - if self.loss_name.startswith("atomwise"): - loss = self.loss_fn(input, target, natoms) - else: - loss = self.loss_fn(input, target) - - if self.reduction == "mean": - num_samples = ( - batch_size if self.loss_name.startswith("atomwise") else input.shape[0] - ) - num_samples = distutils.all_reduce(num_samples, device=input.device) - # Multiply by world size since gradients are averaged - # across DDP replicas - return loss * distutils.get_world_size() / num_samples - else: - # if reduction is sum or mean over all axes, no other operations are needed - return loss + loss = self.loss_fn(input, target, natoms) + return self._reduction(input, loss, natoms) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 0ead76b2c..1a7d0a7d8 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -22,7 +22,6 @@ import numpy as np import numpy.typing as npt import torch -import torch.nn as nn import yaml from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader @@ -40,7 +39,6 @@ from fairchem.core.common.typing import none_throws from fairchem.core.common.utils import ( get_commit_hash, - get_loss_module, load_state_dict, match_state_dict, save_checkpoint, @@ -671,18 +669,17 @@ def load_loss(self) -> None: self.loss_functions = [] for _idx, loss in enumerate(self.config["loss_functions"]): for target in loss: - loss_name = loss[target].get("fn", "mae") - coefficient = loss[target].get("coefficient", 1) - loss_reduction = loss[target].get("reduction", "mean") - - ### if torch module name provided, use that directly - if hasattr(nn, loss_name): - loss_fn = getattr(nn, loss_name)() - ### otherwise, retrieve the correct module based off old naming - else: - loss_fn = get_loss_module(loss_name) - - loss_fn = DDPLoss(loss_fn, loss_name, loss_reduction) + assert ( + "fn" in loss[target] + ), f"'fn' is not defined in the {target} loss config {loss[target]}." + loss_name = loss[target].get("fn") + assert ( + "coefficient" in loss[target] + ), f"'coefficient' is not defined in the {target} loss config {loss[target]}." + coefficient = loss[target].get("coefficient") + loss_reduction = loss[target].get("reduction") + + loss_fn = DDPLoss(loss_name, reduction=loss_reduction) self.loss_functions.append( (target, {"fn": loss_fn, "coefficient": coefficient}) diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 218b1d208..9a13faed6 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -349,8 +349,7 @@ def _compute_loss(self, out, batch) -> torch.Tensor: * loss_info["fn"]( pred, target, - natoms=natoms, - batch_size=batch_size, + natoms=batch.natoms, ) ) @@ -406,6 +405,15 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): targets["natoms"] = natoms out["natoms"] = natoms + # add all other tensor properties too, but filter out the ones that are changed above + for key in filter( + lambda k: k not in [*list(self.output_targets.keys()), "natoms"] + and isinstance(batch[k], torch.Tensor), + batch.keys(), + ): + targets[key] = batch[key].to(self.device) + out[key] = targets[key] + return evaluator.eval(out, targets, prev_metrics=metrics) # Takes in a new data source and generates predictions on it. diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index 81b945b8b..042303eb1 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -15,9 +15,7 @@ ("escn_hydra"), ], ) -def test_smoke_s2efs_predict( - model_name, configs, dummy_binary_dataset_path, tmpdir -): +def test_smoke_s2efs_predict(model_name, configs, dummy_binary_dataset_path, tmpdir): # train an s2ef model just to have one input_yaml = configs[model_name] train_rundir = tmpdir / "train" @@ -42,8 +40,10 @@ def test_smoke_s2efs_predict( {"forces": {"fn": "l2mae", "coefficient": 100}}, {"stress": {"fn": "mae", "coefficient": 100}}, ], - "outputs": {"stress": {"level": "system", "irrep_dim": 2, "property": "stress"}}, - "evaluation_metrics": {"metrics": {"stress": ["mae"]}}, + "outputs": { + "stress": {"level": "system", "irrep_dim": 2, "property": "stress"} + }, + "evaluation_metrics": {"metrics": {"stress": ["mae", "per_atom_mae"]}}, "dataset": { "train": { "src": str(dummy_binary_dataset_path), diff --git a/tests/core/modules/test_loss.py b/tests/core/modules/test_loss.py new file mode 100644 index 000000000..fe1f9c8fa --- /dev/null +++ b/tests/core/modules/test_loss.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import pytest +import torch +from torch import nn + +from fairchem.core.common import distutils +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.modules.loss import ( + DDPLoss, + L2NormLoss, + MAELoss, + MSELoss, + PerAtomMAELoss, +) + + +@pytest.fixture() +def energy(): + # batch size = 4 + pred = torch.rand([4, 1]) + target = torch.rand([4, 1]) + return pred, target + + +@pytest.fixture() +def forces(): + # batch size = 4 + # total atoms = 100 + pred = torch.rand(100, 3) + target = torch.rand(100, 3) + return pred, target + + +@pytest.fixture() +def anisotropic_stress(): + # batch size = 4 + pred = torch.rand([4, 5]) + target = torch.rand([4, 5]) + return pred, target + + +@pytest.fixture() +def natoms(): + # batch size = 4 + # total atoms = 100 + return torch.tensor([25, 34, 21, 20]) + + +def test_mae(energy, forces, natoms): + loss = MAELoss() + ref_loss = nn.L1Loss(reduction="none") + pred, target = energy + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + pred, target = forces + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + + +def test_mse(energy, forces): + loss = MSELoss() + ref_loss = nn.MSELoss(reduction="none") + pred, target = energy + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + pred, target = forces + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + + +def test_per_atom_mae(energy, natoms): + loss = PerAtomMAELoss() + ref_loss = nn.L1Loss(reduction="none") + pred, target = energy + _natoms = torch.reshape(natoms, target.shape) + assert target.shape == (target / _natoms).shape + assert torch.allclose( + loss(pred, target, natoms), ref_loss(pred / _natoms, target / _natoms) + ) + + +def test_l2norm(forces, natoms): + loss = L2NormLoss() + pred, target = forces + ref_norm = torch.linalg.vector_norm(pred - target, ord=2, dim=-1) + assert torch.allclose(loss(pred, target, natoms), ref_norm) + + +def test_energy_mae_reduction(energy, natoms): + # this is testing on a single process i.e. world_size=1 + # mean reduction + loss = DDPLoss(loss_name="mae", reduction="mean") + ref_loss = nn.L1Loss(reduction="mean") + pred, target = energy + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + # sum reduction + loss = DDPLoss(loss_name="mae", reduction="sum") + ref_loss = nn.L1Loss(reduction="sum") + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + + +def test_stress_mae_reduction(anisotropic_stress, natoms): + # this is testing on a single process i.e. world_size=1 + # mean reduction + loss = DDPLoss(loss_name="mae", reduction="mean") + ref_loss = nn.L1Loss(reduction="mean") + pred, target = anisotropic_stress + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + # sum reduction + loss = DDPLoss(loss_name="mae", reduction="sum") + ref_loss = nn.L1Loss(reduction="sum") + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + + +def test_l2norm_reduction(forces, natoms): + # this is testing on a single process i.e. world_size=1 + # mean reduction + loss = DDPLoss(loss_name="l2norm", reduction="mean") + pred, target = forces + ref_norm = torch.linalg.vector_norm(pred - target, ord=2, dim=-1) + ref_loss = ref_norm.mean() + assert torch.allclose(loss(pred, target, natoms), ref_loss) + # sum reduction + loss = DDPLoss(loss_name="l2norm", reduction="sum") + ref_loss = ref_norm.sum() + assert torch.allclose(loss(pred, target, natoms), ref_loss) + + +def test_mse_reduction(forces, natoms): + # this is testing on a single process i.e. world_size=1 + # mean reduction + loss = DDPLoss(loss_name="mse", reduction="mean") + ref_loss = nn.MSELoss(reduction="mean") + pred, target = forces + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + # sum reduction + loss = DDPLoss(loss_name="mse", reduction="sum") + ref_loss = nn.MSELoss(reduction="sum") + assert torch.allclose(loss(pred, target, natoms), ref_loss(pred, target)) + + +def split_batch_for_ddp( + task: str, pred: torch.Tensor, target: torch.Tensor, natoms: torch.Tensor +): + if task == "energy": + return list(torch.split(pred, 1)), list(torch.split(target, 1)) + elif task == "forces": + split_shape = natoms.tolist() + return list(torch.split(pred, split_shape, dim=0)), list( + torch.split(target, split_shape, dim=0) + ) + else: + raise ValueError(f"Invalid task: {task}") + + +def run_ddp_loss(pred, target, natoms, loss_name, reduction): + loss = DDPLoss(loss_name=loss_name, reduction=reduction) + local_rank = distutils.get_rank() + return loss(pred[int(local_rank)], target[int(local_rank)], natoms) + + +@pytest.fixture() +def world_size(): + # batch size = 4 + return 4 + + +def test_ddp_mae(energy, natoms, world_size): + pred, target = energy + ddp_pred, ddp_target = split_batch_for_ddp("energy", pred, target, natoms) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + output = spawn_multi_process( + config, + run_ddp_loss, + init_pg_and_rank_and_launch_test, + ddp_pred, + ddp_target, + natoms, + "mae", + "mean", + ) + # this mocks what ddp does when averaging gradients + ddp_loss = torch.sum(torch.tensor(output)) / float(world_size) + ref_loss = nn.L1Loss(reduction="mean") + assert torch.allclose(ddp_loss, ref_loss(pred, target)) + + +def test_ddp_l2norm(forces, natoms, world_size): + pred, target = forces + ddp_pred, ddp_target = split_batch_for_ddp("forces", pred, target, natoms) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + output = spawn_multi_process( + config, + run_ddp_loss, + init_pg_and_rank_and_launch_test, + ddp_pred, + ddp_target, + natoms, + "l2norm", + "mean", + ) + # this mocks what ddp does when averaging gradients + ddp_loss = torch.sum(torch.tensor(output)) / float(world_size) + ref_norm = torch.linalg.vector_norm(pred - target, ord=2, dim=-1) + ref_loss = ref_norm.mean() + assert torch.allclose(ddp_loss, ref_loss) + + +def test_ddp_mse(forces, natoms, world_size): + pred, target = forces + ddp_pred, ddp_target = split_batch_for_ddp("forces", pred, target, natoms) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + output = spawn_multi_process( + config, + run_ddp_loss, + init_pg_and_rank_and_launch_test, + ddp_pred, + ddp_target, + natoms, + "mse", + "mean", + ) + # this mocks what ddp does when averaging gradients + ddp_loss = torch.sum(torch.tensor(output)) / float(world_size) + ref_loss = nn.MSELoss(reduction="mean") + assert torch.allclose(ddp_loss, ref_loss(pred, target))