Skip to content

Commit

Permalink
fix: Allow updates to scalars, without exclusion and more (ecmwf#137)
Browse files Browse the repository at this point in the history
* fix: Allow updates to scalars
* Add 'add' & 'update'
* Add without & without_by_dim
* Rework loss functions to use without
- Allow limiting of scalars rather than turning off
* Rework feature_indices to scalar_indices
* Remove without in validation_metrics
  • Loading branch information
HCookie authored Nov 14, 2024
1 parent cce60f0 commit 76d3ef6
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 90 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you!
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pull/137)
- Add without subsetting in ScaleTensor
- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [#38](https://github.com/ecmwf/anemoi-training/pull/38/)
Expand Down
17 changes: 9 additions & 8 deletions src/anemoi/training/losses/huber.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def forward(
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted Huber loss.
Expand All @@ -86,10 +86,11 @@ def forward(
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None
Returns
-------
Expand All @@ -98,6 +99,6 @@ def forward(
"""
out = self.huber(pred, target)

if feature_scale:
out = self.scale_by_variable_scaling(out, feature_indices)
out = self.scale(out, scalar_indices, without_scalars=without_scalars)

return self.scale_by_node_weights(out, squash)
17 changes: 8 additions & 9 deletions src/anemoi/training/losses/logcosh.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def forward(
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted LogCosh loss.
Expand All @@ -80,10 +80,11 @@ def forward(
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None
Returns
-------
Expand All @@ -92,7 +93,5 @@ def forward(
"""
out = LogCosh.apply(pred - target)

if feature_scale:
out = self.scale(out, feature_indices)
out = self.scale(out, scalar_indices, without_scalars=without_scalars)
return self.scale_by_node_weights(out, squash)
18 changes: 9 additions & 9 deletions src/anemoi/training/losses/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def forward(
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted MAE loss.
Expand All @@ -66,18 +66,18 @@ def forward(
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None
Returns
-------
torch.Tensor
Weighted MAE loss
"""
out = torch.abs(pred - target)

if feature_scale:
out = self.scale(out, feature_indices)
out = self.scale(out, scalar_indices, without_scalars=without_scalars)
return self.scale_by_node_weights(out, squash)
17 changes: 8 additions & 9 deletions src/anemoi/training/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def forward(
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted MSE loss.
Expand All @@ -64,18 +64,17 @@ def forward(
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None
Returns
-------
torch.Tensor
Weighted MSE loss
"""
out = torch.square(pred - target)

if feature_scale:
out = self.scale(out, feature_indices)
out = self.scale(out, scalar_indices, without_scalars=without_scalars)
return self.scale_by_node_weights(out, squash)
17 changes: 9 additions & 8 deletions src/anemoi/training/losses/rmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def forward(
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted RMSE loss.
Expand All @@ -63,10 +63,11 @@ def forward(
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None
Returns
-------
Expand All @@ -77,7 +78,7 @@ def forward(
pred=pred,
target=target,
squash=squash,
feature_indices=feature_indices,
feature_scale=feature_scale,
scalar_indices=scalar_indices,
without_scalars=without_scalars,
)
return torch.sqrt(mse)
Loading

0 comments on commit 76d3ef6

Please sign in to comment.