Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mishooax committed Sep 27, 2024
1 parent 832c614 commit bb15293
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/anemoi/models/preprocessing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
], f"{method} is not a valid normalisation method"

def transform(
self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None
self, x: torch.Tensor, in_place: bool, data_index: torch.Tensor,
) -> torch.Tensor:
"""Normalizes an input tensor x of shape [..., nvars].
Expand All @@ -157,14 +157,10 @@ def transform(
_description_
"""
if not in_place:
x = x.clone()
x = x.clone() # TODO: fix this; implement a custom clone() op?

if data_index is not None:
x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index]
elif x.shape[-1] == len(self._input_idx):
x[..., :] = x[..., :] * self._norm_mul[self._input_idx] + self._norm_add[self._input_idx]
else:
x[..., :] = x[..., :] * self._norm_mul + self._norm_add
assert data_index is not None # [Mihai] we require a data_index
x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index]
return x

def inverse_transform(
Expand Down

0 comments on commit bb15293

Please sign in to comment.