diff --git a/CHANGELOG.md b/CHANGELOG.md
index ec4db72..7811ef3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,26 +10,31 @@ Keep it human-readable, your future self will thank you!
 
 ## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD)
 
-## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables
-
 ### Added
-
-- CI workflow to update the changelog on release
-- configurabilty of the dropout probability in the the MultiHeadSelfAttention module
-- CI workflow to update the changelog on release
-- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
 - Codeowners file
 - Pygrep precommit hooks
 - Docsig precommit hooks
 - Changelog merge strategy
+- Configurability of the dropout probability in the MultiHeadSelfAttention module
+- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
+
+### Changed
+- Bugfixes for CI
+
+### Removed
+
+## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables
+### Added
+
+- CI workflow to update the changelog on release
+- Remapper: preprocessor for remapping one variable to multiple ones, with optional config keywords. Includes changes to the data indices, since the remapper changes the number of variables.
 
 ### Changed
 - Update CI to inherit from common infrastructure reusable workflows
 - run downstream-ci only when src and tests folders have changed
 - New error messages for wrong graphs.
-- Bugfixes for CI
 
 ### Removed
 
diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py
new file mode 100644
index 0000000..3791ff2
--- /dev/null
+++ b/src/anemoi/models/layers/bounding.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from abc import ABC
+from abc import abstractmethod
+
+import torch
+from torch import nn
+
+from anemoi.models.data_indices.tensor import InputTensorIndex
+
+
+class BaseBounding(nn.Module, ABC):
+    """Abstract base class for bounding strategies.
+
+    This class defines an interface for bounding strategies which are used to apply a specific
+    restriction to the predictions of a model.
+    """
+
+    def __init__(
+        self,
+        *,
+        variables: list[str],
+        name_to_index: dict,
+    ) -> None:
+        super().__init__()
+
+        self.name_to_index = name_to_index
+        self.variables = variables
+        self.data_index = self._create_index(variables=self.variables)
+
+    def _create_index(self, variables: list[str]) -> InputTensorIndex:
+        return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the bounding to the predictions.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The tensor containing the predictions that will be bounded.
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor with the bounding applied.
+        """
+        pass
+
+
+class ReluBounding(BaseBounding):
+    """Initializes the bounding with a ReLU activation / zero clamping."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
+        return x
+
+
+class HardtanhBounding(BaseBounding):
+    """Initializes the bounding with specified minimum and maximum values for bounding.
+
+    Parameters
+    ----------
+    variables : list[str]
+        A list of strings representing the variables that will be bounded.
+    name_to_index : dict
+        A dictionary mapping the variable names to their corresponding indices.
+    min_val : float
+        The minimum value for the HardTanh activation.
+    max_val : float
+        The maximum value for the HardTanh activation.
+    """
+
+    def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None:
+        super().__init__(variables=variables, name_to_index=name_to_index)
+        self.min_val = min_val
+        self.max_val = max_val
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x[..., self.data_index] = torch.nn.functional.hardtanh(
+            x[..., self.data_index], min_val=self.min_val, max_val=self.max_val
+        )
+        return x
+
+
+class FractionBounding(HardtanhBounding):
+    """Initializes the FractionBounding with specified parameters.
+
+    Parameters
+    ----------
+    variables : list[str]
+        A list of strings representing the variables that will be bounded.
+    name_to_index : dict
+        A dictionary mapping the variable names to their corresponding indices.
+    min_val : float
+        The minimum value for the HardTanh activation.
+    max_val : float
+        The maximum value for the HardTanh activation.
+    total_var : str
+        A string representing a variable from which a secondary variable is derived. For
+        example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation).
+    """
+
+    def __init__(
+        self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str
+    ) -> None:
+        super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
+        self.total_variable = self._create_index(variables=[total_var])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Apply the HardTanh bounding to the data_index variables
+        x = super().forward(x)
+        # Calculate the fraction of the total variable
+        x[..., self.data_index] *= x[..., self.total_variable]
+        return x
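A minimal standalone sketch of the three bounding strategies above, for review purposes. It mirrors only the tensor operations; in the real classes the channel indices are resolved from variable names via `InputTensorIndex`, so the indices below are made up for illustration.

```python
import torch

x = torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0]])
idx = [0, 1]      # hypothetical channels to bound ("var1", "var2")
total_idx = [2]   # hypothetical channel of the total variable

# ReluBounding: clamp the selected channels at zero.
x_relu = x.clone()
x_relu[..., idx] = torch.nn.functional.relu(x_relu[..., idx])

# HardtanhBounding: clamp the selected channels to [min_val, max_val].
x_tanh = x.clone()
x_tanh[..., idx] = torch.nn.functional.hardtanh(x_tanh[..., idx], min_val=0.0, max_val=1.0)

# FractionBounding: clamp to [0, 1], then rescale by the total variable, so the
# bounded channels become a fraction of it (e.g. convective vs. total precipitation).
x_frac = x.clone()
x_frac[..., idx] = torch.nn.functional.hardtanh(x_frac[..., idx], min_val=0.0, max_val=1.0)
x_frac[..., idx] *= x_frac[..., total_idx]
```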
diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py
index 3414dc5..b043b0c 100644
--- a/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/src/anemoi/models/models/encoder_processor_decoder.py
@@ -67,6 +67,8 @@ def __init__(
         self._register_latlon("data", self._graph_name_data)
         self._register_latlon("hidden", self._graph_name_hidden)
 
+        self.data_indices = data_indices
+
         self.num_channels = config.model.num_channels
 
         input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
@@ -103,6 +105,14 @@ def __init__(
             dst_grid_size=self._data_grid_size,
         )
 
+        # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
+        self.boundings = nn.ModuleList(
+            [
+                instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index)
+                for cfg in getattr(config.model, "bounding", [])
+            ]
+        )
+
     def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
         self.num_input_channels = len(data_indices.internal_model.input)
         self.num_output_channels = len(data_indices.internal_model.output)
@@ -251,4 +261,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
 
         # residual connection (just for the prognostic variables)
         x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
+
+        for bounding in self.boundings:
+            # bounding performed in the order specified in the config file
+            x_out = bounding(x_out)
+
         return x_out
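A hedged sketch of how a model config might wire these layers in via Hydra/OmegaConf. The `tp`/`cp` variable names and the `name_to_index` mapping are hypothetical, but the `_target_` schema matches the test file added below. If `bounding` is absent from the model config, the `getattr` default yields an empty `ModuleList` and the forward pass is unchanged.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "model": {
            "bounding": [
                # Total precipitation is clamped at zero first ...
                {"_target_": "anemoi.models.layers.bounding.ReluBounding", "variables": ["tp"]},
                # ... then convective precipitation is bounded as a fraction of it.
                {
                    "_target_": "anemoi.models.layers.bounding.FractionBounding",
                    "variables": ["cp"],
                    "min_val": 0.0,
                    "max_val": 1.0,
                    "total_var": "tp",
                },
            ]
        }
    }
)
name_to_index = {"tp": 0, "cp": 1}  # hypothetical output-channel mapping
boundings = [instantiate(cfg, name_to_index=name_to_index) for cfg in config.model.bounding]
```

Because the layers are applied in config order, `ReluBounding` on `tp` runs before `FractionBounding` reads it.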
diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py
index 53017fb..cc2cb4f 100644
--- a/src/anemoi/models/preprocessing/__init__.py
+++ b/src/anemoi/models/preprocessing/__init__.py
@@ -38,7 +38,23 @@ def __init__(
             Data indices for input and output variables
         statistics : dict
             Data statistics dictionary
+        data_indices : dict
+            Data indices for input and output variables
+
+        Attributes
+        ----------
+        default : str
+            Default method for variables not specified in the config
+        method_config : dict
+            Dictionary of the methods with lists of variables
+        methods : dict
+            Dictionary of the variables with methods
+        data_indices : IndexCollection
+            Data indices for input and output variables
+        remap : dict
+            Dictionary of the variables with remapped names in the config
         """
+
         super().__init__()
 
         self.default, self.method_config = self._process_config(config)
@@ -47,8 +63,10 @@ def __init__(
         self.data_indices = data_indices
 
     def _process_config(self, config):
+        _special_keys = ["default", "remap"]  # Keys that do not contain a list of variables in a preprocessing method.
         default = config.get("default", "none")
-        method_config = {k: v for k, v in config.items() if k != "default" and v is not None and v != "none"}
+        self.remap = config.get("remap", {})
+        method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}
 
         if not method_config:
             LOGGER.warning(
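Replaying the new `_process_config` filtering on a toy config, to make the special-key handling concrete; the variable and method entries are made up:

```python
config = {
    "default": "mean-std",
    "remap": {"cp": "tp"},  # hypothetical: "cp" reuses the statistics of "tp"
    "min-max": ["x"],
    "max": ["y"],
}
_special_keys = ["default", "remap"]  # keys that do not hold a list of variables
default = config.get("default", "none")
remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}
print(default)        # mean-std
print(remap)          # {'cp': 'tp'}
print(method_config)  # {'min-max': ['x'], 'max': ['y']}
```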
diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py
index bc75466..ee6a4f5 100644
--- a/src/anemoi/models/preprocessing/normalizer.py
+++ b/src/anemoi/models/preprocessing/normalizer.py
@@ -49,6 +49,16 @@ def __init__(
         mean = statistics["mean"]
         stdev = statistics["stdev"]
 
+        # Optionally reuse statistics of one variable for another variable
+        statistics_remap = {}
+        for remap, source in self.remap.items():
+            idx_src, idx_remap = name_to_index_training_input[source], name_to_index_training_input[remap]
+            statistics_remap[idx_remap] = (minimum[idx_src], maximum[idx_src], mean[idx_src], stdev[idx_src])
+
+        # Two-step to avoid overwriting the original statistics in the loop (this reduces dependence on order)
+        for idx, new_stats in statistics_remap.items():
+            minimum[idx], maximum[idx], mean[idx], stdev[idx] = new_stats
+
         self._validate_normalization_inputs(name_to_index_training_input, minimum, maximum, mean, stdev)
 
         _norm_add = np.zeros((minimum.size,), dtype=np.float32)
@@ -56,6 +66,7 @@
 
         for name, i in name_to_index_training_input.items():
             method = self.methods.get(name, self.default)
+
             if method == "mean-std":
                 LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.")
                 if stdev[i] < (mean[i] * 1e-6):
@@ -63,6 +74,13 @@
                 _norm_mul[i] = 1 / stdev[i]
                 _norm_add[i] = -mean[i] / stdev[i]
+            elif method == "std":
+                LOGGER.debug(f"Normalizing: {name} is std-normalised.")
+                if stdev[i] < (mean[i] * 1e-6):
+                    warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}")
+                _norm_mul[i] = 1 / stdev[i]
+                _norm_add[i] = 0
+
             elif method == "min-max":
                 LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].")
                 x = maximum[i] - minimum[i]
@@ -92,16 +110,20 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
                 f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) "
                 f"and entries in config ({sum(len(v) for v in self.method_config)}) do not match."
             )
+
+        # Check that all sizes align
         n = minimum.size
         assert maximum.size == n, (maximum.size, n)
         assert mean.size == n, (mean.size, n)
         assert stdev.size == n, (stdev.size, n)
 
+        # Check for typos in method config
         assert isinstance(self.methods, dict)
         for name, method in self.methods.items():
             assert name in name_to_index_training_input, f"{name} is not a valid variable name"
             assert method in [
                 "mean-std",
+                "std",
                 # "robust",
                 "min-max",
                 "max",
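A worked comparison of the new "std" method against "mean-std": scaling by `1 / stdev` without subtracting the mean keeps zeros at zero, which is often desirable for fields like precipitation. The example numbers are made up:

```python
import numpy as np

mean, stdev = 5.0, 2.0
x = np.array([0.0, 5.0, 9.0], dtype=np.float32)

# "std": _norm_mul = 1 / stdev, _norm_add = 0
print(x * (1 / stdev))     # [0.  2.5 4.5] -- zeros stay zeros

# "mean-std": _norm_mul = 1 / stdev, _norm_add = -mean / stdev
print((x - mean) / stdev)  # [-2.5  0.   2. ]
```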
"max_val": 1.0, + }, + { + "_target_": "anemoi.models.layers.bounding.FractionBounding", + "variables": config.variables, + "min_val": 0.0, + "max_val": 1.0, + "total_var": config.total_var, + }, + ] + for layer_definition in layer_definitions: + bounding = instantiate(layer_definition, name_to_index=name_to_index) + bounding(input_tensor.clone()) diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py index cc527e7..8056865 100644 --- a/tests/preprocessing/test_preprocessor_normalizer.py +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -40,6 +40,37 @@ def input_normalizer(): return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics) +@pytest.fixture() +def remap_normalizer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "normalizer": { + "default": "mean-std", + "remap": {"x": "z", "y": "x"}, + "min-max": ["x"], + "max": ["y"], + "none": ["z"], + "mean-std": ["q"], + }, + "forcing": ["z", "q"], + "diagnostic": ["other"], + "remapped": {}, + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5, 1, 14]), + "minimum": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0, 10.0, 10.0]), + } + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputNormalizer(config=config.data.normalizer, statistics=statistics, data_indices=data_indices) + + def test_normalizer_not_inplace(input_normalizer) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) input_normalizer(x, in_place=False) @@ -87,3 +118,15 @@ def test_normalize_inverse_transform(input_normalizer) -> None: assert torch.allclose( input_normalizer.inverse_transform(input_normalizer.transform(x, in_place=False), in_place=False), x ) + + +def test_normalizer_not_inplace_remap(remap_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + remap_normalizer(x, in_place=False) + assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])) + + +def test_normalize_remap(remap_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + expected_output = torch.Tensor([[0.0, 2 / 11, 3.0, -0.5, 1 / 7], [5 / 9, 7 / 11, 8.0, 4.5, 0.5]]) + assert torch.allclose(remap_normalizer.transform(x), expected_output)