diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py index 44aae2d..07dfdfc 100644 --- a/src/anemoi/models/preprocessing/normalizer.py +++ b/src/anemoi/models/preprocessing/normalizer.py @@ -133,7 +133,7 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min ], f"{method} is not a valid normalisation method" def transform( - self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + self, x: torch.Tensor, in_place: bool, data_index: torch.Tensor, ) -> torch.Tensor: """Normalizes an input tensor x of shape [..., nvars]. @@ -157,14 +157,10 @@ def transform( _description_ """ if not in_place: - x = x.clone() + x = x.clone() # TODO: fix this; implement a custom clone() op? - if data_index is not None: - x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index] - elif x.shape[-1] == len(self._input_idx): - x[..., :] = x[..., :] * self._norm_mul[self._input_idx] + self._norm_add[self._input_idx] - else: - x[..., :] = x[..., :] * self._norm_mul + self._norm_add + assert data_index is not None # [Mihai] we require a data_index + x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index] return x def inverse_transform(