From eea2c26a6aafee1eb7f2f61468edad98893d70f9 Mon Sep 17 00:00:00 2001 From: Gabriel Moldovan Date: Fri, 2 Aug 2024 19:03:08 +0000 Subject: [PATCH] precommit changes --- .../models/models/encoder_processor_decoder.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 50cb03a..b9995c7 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -21,9 +21,8 @@ from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards -from anemoi.models.models.bounding import BaseBoundingStrategy from anemoi.models.layers.graph import TrainableTensor - +from anemoi.models.models.bounding import BaseBoundingStrategy LOGGER = logging.getLogger(__name__) @@ -78,7 +77,9 @@ def create_bounding_strategy(config: DotDict) -> BaseBoundingStrategy: # Check if bounding_strategies_config is not None and not empty if bounding_strategies_config: - self.bounding_strategies = {var: create_bounding_strategy(cfg) for var, cfg in bounding_strategies_config.items()} + self.bounding_strategies = { + var: create_bounding_strategy(cfg) for var, cfg in bounding_strategies_config.items() + } else: self.bounding_strategies = {} @@ -266,7 +267,10 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> # residual connection (just for the prognostic variables) x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] - for var, strategy in self.bounding_strategies.items(): # bounding performed in the order specified in the config file + for ( + var, + strategy, + ) in self.bounding_strategies.items(): # bounding performed in the order specified in the config file indices = [] indices.append(self.data_indices.model.output.name_to_index[var])