Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/mask NaNs in training loss function #56

Merged
merged 15 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
- configurabilty of the dropout probability in the the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

Expand Down
13 changes: 12 additions & 1 deletion src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
super().__init__(config, data_indices, statistics)

self.nan_locations = None
# weight imputed values wiht zero in loss calculation
self.loss_mask_training = None

def _validate_indices(self):
assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), (
Expand Down Expand Up @@ -109,12 +111,21 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()

# Initilialize mask once
# Initialize nan mask once
if self.nan_locations is None:
# The mask is only saved for the last two dimensions (grid, variable)
idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
self.nan_locations = torch.isnan(x[idx].squeeze())

# Initialize training loss mask to weigh imputed values with zeroes once
self.loss_mask_training = torch.ones(
(x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
) # shape (grid, n_outputs)
# for all variables that are imputed and part of the model output, set the loss weight to zero
for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
if idx_dst is not None:
self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

# Choose correct index based on number of variables
if x.shape[-1] == self.num_training_input_vars:
index = self.index_training_input
Expand Down
29 changes: 29 additions & 0 deletions src/anemoi/models/preprocessing/remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,35 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:

return x_remapped

def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor:
"""Remap the loss mask.

```
x : torch.Tensor
Loss mask
```
"""
# use indices at model output level
index = self.index_inference_backmapped_output
indices_remapped = self.index_inference_output
indices_keep = self.indices_keep_inference_output

# create new loss mask with target number of columns
mask_remapped = torch.zeros(
mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),), dtype=mask.dtype, device=mask.device
)

# copy loss mask for variables that are not remapped
mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep]

# remap loss mask for rest of variables
for idx_src, idx_dst in zip(indices_remapped, index):
if idx_dst is not None:
for ii in idx_dst:
mask_remapped[..., ii] = mask[..., idx_src]

return mask_remapped

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Convert and remap the output tensor.

Expand Down
20 changes: 20 additions & 0 deletions tests/preprocessing/test_preprocessor_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,26 @@ def test_mask_saving(imputer_fixture, data_fixture, request):
assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
)
def test_loss_nan_mask(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
x, _ = request.getfixturevalue(data_fixture)
expected = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]]) # only prognostic and diagnostic variables
imputer = request.getfixturevalue(imputer_fixture)
imputer.transform(x)
assert torch.allclose(
imputer.loss_mask_training, expected
), "Transform does not calculate NaN-mask for loss function scaling correctly."


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
Expand Down
40 changes: 40 additions & 0 deletions tests/preprocessing/test_preprocessor_remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
# nor does it submit to any jurisdiction.


import numpy as np
import pytest
import torch
from omegaconf import DictConfig

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing.imputer import InputImputer
from anemoi.models.preprocessing.remapper import Remapper


Expand Down Expand Up @@ -41,6 +43,34 @@ def input_remapper():
return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics)


@pytest.fixture()
def input_imputer():
config = DictConfig(
{
"diagnostics": {"log": {"code": {"level": "DEBUG"}}},
"data": {
"remapper": {
"cos_sin": {
"d": ["cos_d", "sin_d"],
}
},
"imputer": {"default": "none", "mean": ["y", "d"]},
"forcing": ["z", "q"],
"diagnostic": ["other"],
"remapped": {
"d": ["cos_d", "sin_d"],
},
},
},
)
statistics = {
"mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0, 1.0]),
}
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
data_indices = IndexCollection(config=config, name_to_index=name_to_index)
return InputImputer(config=config.data.imputer, data_indices=data_indices, statistics=statistics)


def test_remap_not_inplace(input_remapper) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
input_remapper(x, in_place=False)
Expand All @@ -66,3 +96,13 @@ def test_remap_inverse_transform(input_remapper) -> None:
assert torch.allclose(
input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x
)


def test_transform_loss_mask(input_imputer, input_remapper) -> None:
x = torch.Tensor([[1.0, np.nan, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, np.nan, 10.0]])
expected_output = torch.Tensor([[1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 0.0]])
input_imputer.transform(x)
input_remapper.transform(x)
loss_mask_training = input_imputer.loss_mask_training
loss_mask_training = input_remapper.transform_loss_mask(loss_mask_training)
assert torch.allclose(loss_mask_training, expected_output)
Loading