Skip to content

Commit

Permalink
precommit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieloks committed Aug 2, 2024
1 parent 2418ca4 commit eea2c26
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.models.bounding import BaseBoundingStrategy
from anemoi.models.layers.graph import TrainableTensor

from anemoi.models.models.bounding import BaseBoundingStrategy
LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -78,7 +77,9 @@ def create_bounding_strategy(config: DotDict) -> BaseBoundingStrategy:

# Check if bounding_strategies_config is not None and not empty
if bounding_strategies_config:
self.bounding_strategies = {var: create_bounding_strategy(cfg) for var, cfg in bounding_strategies_config.items()}
self.bounding_strategies = {
var: create_bounding_strategy(cfg) for var, cfg in bounding_strategies_config.items()
}
else:
self.bounding_strategies = {}

Expand Down Expand Up @@ -266,7 +267,10 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

for var, strategy in self.bounding_strategies.items(): # bounding performed in the order specified in the config file
for (
var,
strategy,
) in self.bounding_strategies.items(): # bounding performed in the order specified in the config file
indices = []
indices.append(self.data_indices.model.output.name_to_index[var])

Expand Down

0 comments on commit eea2c26

Please sign in to comment.