From 792c032c7c3a6fbb508cf8b7cdbe42f795d99ad2 Mon Sep 17 00:00:00 2001
From: "rilwan.adewoyin@ecmwf.int"
Date: Sat, 10 Aug 2024 10:12:59 +0000
Subject: [PATCH 1/8] WIP: Adding tendency training

---
 src/anemoi/models/interface/__init__.py | 67 +++++++++++++++++++------
 1 file changed, 53 insertions(+), 14 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index 54c548df..291bff55 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -17,6 +17,7 @@
 from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
 from anemoi.models.preprocessing import Processors
+from typing import Optional


 class AnemoiModelInterface(torch.nn.Module):
     """An interface for Anemoi models.
@@ -49,29 +50,44 @@ class AnemoiModelInterface(torch.nn.Module):
     """

     def __init__(
-        self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
+        self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict, tendency_statistics: Optional[dict] = None
     ) -> None:
         super().__init__()
         self.config = config
         self.id = str(uuid.uuid4())
         self.multi_step = self.config.training.multistep_input
+        self.flag_tendency = self.config.training.flag_tendency
         self.graph_data = graph_data
         self.statistics = statistics
+        self.tendency_statistics = tendency_statistics
         self.metadata = metadata
         self.data_indices = data_indices
         self._build_model()

     def _build_model(self) -> None:
         """Builds the model and pre- and post-processors."""
-        # Instantiate processors
-        processors = [
+        # Instantiate processors for state
+        processors_state = [
             [name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)]
-            for name, processor in self.config.data.processors.items()
+            for name, processor in self.config.data.processors.state.items()
         ]

         # Assign the processor list pre- and post-processors
-        self.pre_processors = Processors(processors)
-        self.post_processors = Processors(processors, inverse=True)
+        self.pre_processors_state = Processors(processors_state)
+        self.post_processors_state = Processors(processors_state, inverse=True)
+
+        # Instantiate processors for tendency
+        self.pre_processors_tendency = None
+        self.post_processors_tendency = None
+        if self.flag_tendency:
+            processors_tendency = [
+                [name, instantiate(processor, statistics=self.tendency_statistics, data_indices=self.data_indices)]
+                for name, processor in self.config.data.processors.tendency.items()
+            ]
+
+            self.pre_processors_tendency = Processors(processors_tendency)
+            self.post_processors_tendency = Processors(processors_tendency, inverse=True)
+
         # Instantiate the model (Can be generalised to other models in the future, here we use AnemoiModelEncProcDec)
         self.model = AnemoiModelEncProcDec(
@@ -94,17 +110,40 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Predicted data.
         """
-        batch = self.pre_processors(batch, in_place=False)
-
+
         with torch.no_grad():
-
+
             assert (
                 len(batch.shape) == 4
             ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"
-            # Dimensions are
-            # batch, timesteps, horizonal space, variables
-            x = batch[:, 0 : self.multi_step, None, ...]  # add dummy ensemble dimension as 3rd index
-            y_hat = self(x)
+            x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False)
+
+            # Dimensions are
+            # batch, timesteps, horizontal space, variables
+            x = x[..., None, :]  # add dummy ensemble dimension as 3rd index
+
+            #NOTE: TENDENCY CHANGES HERE
+            if not self.flag_tendency:
+                y_hat = self(x)
+                y_hat = self.post_processors(y_hat, in_place=False)
+            else:
+                tendency_hat = self(x)
+                y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...] , tendency_hat )
+
+        return y_hat
+
+    def add_tendency_to_state(self, state_in, tendency_in):
+        #NOTE: TENDENCY CHANGES HERE
+        state_new = self.post_processors_tendency(tendency_in, in_place=False, data_index=self.data_indices.data.output.full)
+        # diagnostic fields are denormallised as full fields
+        # NOTE: TENDENCY ISSUE - Explain this ???
+        state_new[..., self.data_indices.model.output.diagnostic] = self.post_processors_state(
+            tendency_in[..., self.data_indices.model.output.diagnostic],
+            in_place=False,
+            data_index=self.data_indices.data.output.diagnostic,
+        )
+        # add state for prognostic variables
+        state_new[..., self.data_indices.model.output.prognostic] += state_in[..., self.data_indices.model.input.prognostic]
-        return self.post_processors(y_hat, in_place=False)
+        return state_new
\ No newline at end of file
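
Note on the formulation introduced by this first patch (and refined by the later ones): instead of predicting the next state directly, the network predicts a normalised tendency, and predict_step rebuilds the state by denormalising that tendency and adding it to the unprocessed input state for the prognostic variables, while diagnostic variables are denormalised as full fields. A minimal standalone sketch of that composition follows; the denormaliser callables and the plain index lists are stand-ins for the anemoi post-processors and IndexCollection entries, not the real API.

    import torch

    def add_tendency_to_state(state_in, tendency_norm, denorm_tendency, denorm_state,
                              prognostic_idx, diagnostic_idx):
        # Denormalise the predicted tendency for all output variables.
        state_new = denorm_tendency(tendency_norm)
        # Diagnostic variables are not accumulated: they are denormalised as full fields.
        state_new[..., diagnostic_idx] = denorm_state(tendency_norm[..., diagnostic_idx])
        # Prognostic variables: previous state plus denormalised increment.
        state_new[..., prognostic_idx] += state_in[..., prognostic_idx]
        return state_new

    # Example with identity (de)normalisers on a (batch, grid, vars) tensor:
    state = torch.randn(2, 10, 4)
    tend = torch.randn(2, 10, 4)
    out = add_tendency_to_state(state, tend, lambda t: t.clone(), lambda t: t, [0, 1], [2, 3])
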
""" def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict, tendency_statistics: Optional[dict] = None + self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict, statistics_tendencies: Optional[dict] = None ) -> None: super().__init__() self.config = config self.id = str(uuid.uuid4()) self.multi_step = self.config.training.multistep_input - self.flag_tendency = self.config.training.flag_tendency + self.tendency_mode = self.config.training.tendency_mode self.graph_data = graph_data self.statistics = statistics - self.tendency_statistics = tendency_statistics + self.statistics_tendencies = statistics_tendencies self.metadata = metadata self.data_indices = data_indices self._build_model() @@ -79,9 +75,9 @@ def _build_model(self) -> None: # Instantiate processors for tendency self.pre_processors_tendency = None self.post_processors_tendency = None - if self.flag_tendency: + if self.tendency_mode: processors_tendency = [ - [name, instantiate(processor, statistics=self.tendency_statistics, data_indices=self.data_indices)] + [name, instantiate(processor, statistics=self.statistics_tendencies, data_indices=self.data_indices)] for name, processor in self.config.data.processors.tendency.items() ] @@ -123,27 +119,39 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor: # batch, timesteps, horizontal space, variables x = x[..., None, :] # add dummy ensemble dimension as 3rd index - #NOTE: TENDENCY CHANGES HERE - if not self.flag_tendency: + if not self.tendency_mode: y_hat = self(x) - y_hat = self.post_processors(y_hat, in_place=False) + y_hat = self.post_processors_state(y_hat, in_place=False) else: tendency_hat = self(x) y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...] , tendency_hat ) return y_hat - def add_tendency_to_state(self, state_in, tendency_in): - #NOTE: TENDENCY CHANGES HERE - state_new = self.post_processors_tendency(tendency_in, in_place=False, data_index=self.data_indices.data.output.full) - # diagnostic fields are denormallised as full fields - # NOTE: TENDENCY ISSUE - Explain this ??? - state_new[..., self.data_indices.model.output.diagnostic] = self.post_processors_state( - tendency_in[..., self.data_indices.model.output.diagnostic], + def add_tendency_to_state(self, state_inp, tendency): + """Add the tendency to the state. + + Parameters + ---------- + state_inp : torch.Tensor + The input state tensor with full input variables and unprocessed. + tendency : torch.Tensor + The tendency tensor output from model. + + Returns + ------- + torch.Tensor + Predicted data. 
+ """ + + state_outp = self.post_processors_tendency(tendency, in_place=False, data_index=self.data_indices.data.output.full) + + state_outp[..., self.data_indices.model.output.diagnostic] = self.post_processors_state( + tendency[..., self.data_indices.model.output.diagnostic], in_place=False, data_index=self.data_indices.data.output.diagnostic, ) - # add state for prognostic variables - state_new[..., self.data_indices.model.output.prognostic] += state_in[..., self.data_indices.model.input.prognostic] - return state_new \ No newline at end of file + state_outp[..., self.data_indices.model.output.prognostic] += state_inp[..., self.data_indices.model.input.prognostic] + + return state_outp \ No newline at end of file diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index 081afafa..e0bd2aba 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -75,7 +75,7 @@ def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[st for variable in variables } - def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: + def forward(self, x, in_place: bool = True, inverse: bool = False, data_index: Optional[torch.Tensor] = None ) -> Tensor: """Process the input tensor. Parameters @@ -93,8 +93,8 @@ def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: Processed tensor """ if inverse: - return self.inverse_transform(x, in_place=in_place) - return self.transform(x, in_place=in_place) + return self.inverse_transform(x, in_place=in_place, data_index=data_index) + return self.transform(x, in_place=in_place, data_index=data_index) def transform(self, x, in_place: bool = True) -> Tensor: """Process the input tensor.""" @@ -135,7 +135,7 @@ def __init__(self, processors: list, inverse: bool = False) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__} [{'inverse' if self.inverse else 'forward'}]({self.processors})" - def forward(self, x, in_place: bool = True) -> Tensor: + def forward(self, x, in_place: bool = True, data_index: Optional[torch.Tensor] = None) -> Tensor: """Process the input tensor. 
Parameters @@ -144,6 +144,8 @@ def forward(self, x, in_place: bool = True) -> Tensor: Input tensor in_place : bool Whether to process the tensor in place + data_index : Optional[torch.Tensor], optional + Normalize only the specified indices, by default None Returns ------- @@ -151,7 +153,7 @@ def forward(self, x, in_place: bool = True) -> Tensor: Processed tensor """ for processor in self.processors.values(): - x = processor(x, in_place=in_place, inverse=self.inverse) + x = processor(x, in_place=in_place, inverse=self.inverse, data_index=data_index) if self.first_run: self.first_run = False diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index a7b0a8ab..17539436 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -104,7 +104,7 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor: """Expand the subset of the mask to the correct shape.""" return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1) - def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + def transform(self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None ) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: x = x.clone() @@ -116,7 +116,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: self.nan_locations = torch.isnan(x[idx].squeeze()) # Choose correct index based on number of variables - if x.shape[-1] == self.num_training_input_vars: + if data_index is not None: + index = data_index + elif x.shape[-1] == self.num_training_input_vars: index = self.index_training_input elif x.shape[-1] == self.num_inference_input_vars: index = self.index_inference_input @@ -132,13 +134,16 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value return x - def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + def inverse_transform(self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None ) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: x = x.clone() # Replace original nans with nan again - if x.shape[-1] == self.num_training_output_vars: + # Choose correct index based on number of variables + if data_index is not None: + index = data_index + elif x.shape[-1] == self.num_training_output_vars: index = self.index_training_output elif x.shape[-1] == self.num_inference_output_vars: index = self.index_inference_output @@ -147,7 +152,7 @@ def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Ten f"Input tensor ({x.shape[-1]}) does not match the training " f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})", ) - + # Replace values for idx_src, idx_dst in zip(self.index_training_input, index): if idx_dst is not None: diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py index 787079d8..8244d185 100644 --- a/tests/preprocessing/test_preprocessor_normalizer.py +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -36,7 +36,7 @@ def input_normalizer(): } name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} data_indices = IndexCollection(config=config, name_to_index=name_to_index) - return InputNormalizer(config=config.data.normalizer, statistics=statistics, 
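
This second patch threads a data_index argument through the preprocessing chain so a processor can be applied to an explicit subset of variables (for example, denormalising a model output that only carries the output variables, or only the diagnostic columns) instead of inferring the subset from the size of the tensor's last dimension. A rough sketch of the idea is below; SimpleNormalizer is a hypothetical stand-in, not the real anemoi InputNormalizer, and the index handling is deliberately simplified.

    from typing import Optional
    import torch

    class SimpleNormalizer:
        def __init__(self, mean: torch.Tensor, std: torch.Tensor):
            self.mean, self.std = mean, std  # per-variable statistics in data-space ordering

        def inverse_transform(self, x: torch.Tensor, in_place: bool = True,
                              data_index: Optional[torch.Tensor] = None) -> torch.Tensor:
            if not in_place:
                x = x.clone()
            if data_index is not None:
                # The caller names the data-space columns that x contains,
                # so only those statistics are applied.
                mean, std = self.mean[data_index], self.std[data_index]
            else:
                # Previous behaviour: assume x carries every variable.
                mean, std = self.mean, self.std
            x.mul_(std).add_(mean)
            return x

    norm = SimpleNormalizer(torch.tensor([0.0, 10.0, 250.0]), torch.tensor([1.0, 5.0, 20.0]))
    y = torch.randn(2, 8, 2)  # a tensor holding only variables 0 and 2
    y_denorm = norm.inverse_transform(y, in_place=False, data_index=torch.tensor([0, 2]))
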
From 03b9603d0be2bb1199e6cf36d9d48c4ffe5ad052 Mon Sep 17 00:00:00 2001
From: "rilwan.adewoyin@ecmwf.int"
Date: Tue, 13 Aug 2024 20:04:54 +0000
Subject: [PATCH 3/8] ran pre-commit

---
 src/anemoi/models/interface/__init__.py     | 34 +++++++++++++--------
 src/anemoi/models/preprocessing/__init__.py |  4 ++-
 src/anemoi/models/preprocessing/imputer.py  | 10 ++++--
 3 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index ea2072c8..d984c8a9 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -8,6 +8,7 @@
 #

 import uuid
+from typing import Optional

 import torch
 from anemoi.utils.config import DotDict
@@ -17,7 +18,6 @@
 from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
 from anemoi.models.preprocessing import Processors
-from typing import Optional


 class AnemoiModelInterface(torch.nn.Module):
     """An interface for Anemoi models.
@@ -46,7 +46,14 @@ class AnemoiModelInterface(torch.nn.Module):
     """

     def __init__(
-        self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict, statistics_tendencies: Optional[dict] = None
+        self,
+        *,
+        config: DotDict,
+        graph_data: HeteroData,
+        statistics: dict,
+        data_indices: dict,
+        metadata: dict,
+        statistics_tendencies: Optional[dict] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.id = str(uuid.uuid4())
@@ -84,7 +91,6 @@ def _build_model(self) -> None:
             self.pre_processors_tendency = Processors(processors_tendency)
             self.post_processors_tendency = Processors(processors_tendency, inverse=True)
-
         # Instantiate the model (Can be generalised to other models in the future, here we use AnemoiModelEncProcDec)
         self.model = AnemoiModelEncProcDec(
             config=self.config, data_indices=self.data_indices, graph_data=self.graph_data
         )
@@ -106,9 +112,9 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Predicted data.
         """
-
+
         with torch.no_grad():
-
+
             assert (
                 len(batch.shape) == 4
             ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"
@@ -124,27 +130,29 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
                 y_hat = self.post_processors_state(y_hat, in_place=False)
             else:
                 tendency_hat = self(x)
-                y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...] , tendency_hat )
+                y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...], tendency_hat)

         return y_hat
-
+
     def add_tendency_to_state(self, state_inp, tendency):
         """Add the tendency to the state.
-
+
         Parameters
         ----------
         state_inp : torch.Tensor
             The input state tensor with full input variables and unprocessed.
         tendency : torch.Tensor
             The tendency tensor output from model.
-
+
         Returns
         -------
         torch.Tensor
             Predicted data.
         """

-        state_outp = self.post_processors_tendency(tendency, in_place=False, data_index=self.data_indices.data.output.full)
+        state_outp = self.post_processors_tendency(
+            tendency, in_place=False, data_index=self.data_indices.data.output.full
+        )

         state_outp[..., self.data_indices.model.output.diagnostic] = self.post_processors_state(
             tendency[..., self.data_indices.model.output.diagnostic],
             in_place=False,
             data_index=self.data_indices.data.output.diagnostic,
         )

-        state_outp[..., self.data_indices.model.output.prognostic] += state_inp[..., self.data_indices.model.input.prognostic]
+        state_outp[..., self.data_indices.model.output.prognostic] += state_inp[
+            ..., self.data_indices.model.input.prognostic
+        ]

-        return state_outp
\ No newline at end of file
+        return state_outp
diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py
index e0bd2aba..b010a7b4 100644
--- a/src/anemoi/models/preprocessing/__init__.py
+++ b/src/anemoi/models/preprocessing/__init__.py
@@ -75,7 +75,9 @@ def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[st
             for variable in variables
         }

-    def forward(self, x, in_place: bool = True, inverse: bool = False, data_index: Optional[torch.Tensor] = None ) -> Tensor:
+    def forward(
+        self, x, in_place: bool = True, inverse: bool = False, data_index: Optional[torch.Tensor] = None
+    ) -> Tensor:
         """Process the input tensor.

         Parameters
diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py
index 17539436..c0701527 100644
--- a/src/anemoi/models/preprocessing/imputer.py
+++ b/src/anemoi/models/preprocessing/imputer.py
@@ -104,7 +104,9 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor:
         """Expand the subset of the mask to the correct shape."""
         return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1)

-    def transform(self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None ) -> torch.Tensor:
+    def transform(
+        self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         """Impute missing values in the input tensor."""
         if not in_place:
             x = x.clone()
@@ -134,7 +136,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True, data_index: Optiona
                 x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value
         return x

-    def inverse_transform(self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None ) -> torch.Tensor:
+    def inverse_transform(
+        self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         """Impute missing values in the input tensor."""
         if not in_place:
             x = x.clone()
@@ -152,7 +156,7 @@ def inverse_transform(self, x: torch.Tensor, in_place: bool = True, data_index:
                 f"Input tensor ({x.shape[-1]}) does not match the training "
                 f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
             )
-
+
         # Replace values
         for idx_src, idx_dst in zip(self.index_training_input, index):
             if idx_dst is not None:

From 7d973a473ecc72386924e79b421921e2d16d29a4 Mon Sep 17 00:00:00 2001
From: "rilwan.adewoyin@ecmwf.int"
Date: Tue, 13 Aug 2024 20:14:06 +0000
Subject: [PATCH 4/8] updated the test_preprocessor_normalizer.py

---
 tests/preprocessing/test_preprocessor_normalizer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py
index 8244d185..10704b9c 100644
--- a/tests/preprocessing/test_preprocessor_normalizer.py
+++ b/tests/preprocessing/test_preprocessor_normalizer.py
@@ -22,7 +22,9 @@ def input_normalizer():
         {
             "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
             "data": {
-                "normalizer": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]},
+                "normalizers": {
+                    "state": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}
+                },
                 "forcing": ["z", "q"],
                 "diagnostic": ["other"],
             },
From f94e46c59f3c1fa9481f0aa754069be842c9cc83 Mon Sep 17 00:00:00 2001
From: Jakob Schloer
Date: Tue, 27 Aug 2024 14:31:27 +0000
Subject: [PATCH 5/8] Move residual from EncDecProcessor to Interface.

---
 src/anemoi/models/interface/__init__.py           | 12 ++++++++++--
 .../models/models/encoder_processor_decoder.py    |  2 --
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index d984c8a9..045b966b 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -13,6 +13,7 @@
 import torch
 from anemoi.utils.config import DotDict
 from hydra.utils import instantiate
+from torch.distributed.distributed_c10d import ProcessGroup
 from torch_geometric.data import HeteroData

 from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
@@ -96,8 +97,15 @@ def _build_model(self) -> None:
             config=self.config, data_indices=self.data_indices, graph_data=self.graph_data
         )

-        # Use the forward method of the model directly
-        self.forward = self.model.forward
+    def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor:
+        if self.tendency_mode:
+            # Predict tendency
+            x_pred = self.model.forward(x, model_comm_group)
+        else:
+            # Predict state by adding residual connection (just for the prognostic variables)
+            x_pred = self.model.forward(x, model_comm_group)
+            x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx]
+        return x_pred

     def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
         """Prediction step for the model.
diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py
index 0f374742..5cda6e17 100644
--- a/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/src/anemoi/models/models/encoder_processor_decoder.py
@@ -248,6 +248,4 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
             .clone()
         )

-        # residual connection (just for the prognostic variables)
-        x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
         return x_out

From 7d49c61f2fa4e2486149413cdced69602612162c Mon Sep 17 00:00:00 2001
From: Jakob Schloer
Date: Tue, 24 Sep 2024 08:49:48 +0000
Subject: [PATCH 6/8] Change keyword to prediction_strategy

---
 src/anemoi/models/interface/__init__.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index 045b966b..99919cd5 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -60,7 +60,7 @@ def __init__(
         self.config = config
         self.id = str(uuid.uuid4())
         self.multi_step = self.config.training.multistep_input
-        self.tendency_mode = self.config.training.tendency_mode
+        self.prediction_strategy = self.config.training.prediction_strategy
         self.graph_data = graph_data
         self.statistics = statistics
         self.statistics_tendencies = statistics_tendencies
         self.metadata = metadata
@@ -83,7 +83,7 @@ def _build_model(self) -> None:
         # Instantiate processors for tendency
         self.pre_processors_tendency = None
         self.post_processors_tendency = None
-        if self.tendency_mode:
+        if self.prediction_strategy == "tendency":
             processors_tendency = [
                 [name, instantiate(processor, statistics=self.statistics_tendencies, data_indices=self.data_indices)]
                 for name, processor in self.config.data.processors.tendency.items()
@@ -98,13 +98,13 @@ def _build_model(self) -> None:
         )

     def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor:
-        if self.tendency_mode:
-            # Predict tendency
-            x_pred = self.model.forward(x, model_comm_group)
-        else:
+        if self.prediction_strategy == "residual":
             # Predict state by adding residual connection (just for the prognostic variables)
             x_pred = self.model.forward(x, model_comm_group)
             x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx]
+        else:
+            x_pred = self.model.forward(x, model_comm_group)
+
         return x_pred

     def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
@@ -133,12 +133,12 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
             # batch, timesteps, horizontal space, variables
             x = x[..., None, :]  # add dummy ensemble dimension as 3rd index

-            if not self.tendency_mode:
-                y_hat = self(x)
-                y_hat = self.post_processors_state(y_hat, in_place=False)
-            else:
+            if self.prediction_strategy == "tendency":
                 tendency_hat = self(x)
                 y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...], tendency_hat)
+            else:
+                y_hat = self(x)
+                y_hat = self.post_processors_state(y_hat, in_place=False)

         return y_hat

From 5ae42df08ee37e57ca444de6da6c030bf4807f2c Mon Sep 17 00:00:00 2001
From: Jakob Schloer
Date: Fri, 27 Sep 2024 13:22:03 +0000
Subject: [PATCH 7/8] Fix predict_step for inference.

---
 src/anemoi/models/interface/__init__.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index 8b3ef82a..7c366825 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -133,16 +133,14 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
             assert (
                 len(batch.shape) == 4
             ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"
-
             x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False)

             # Dimensions are
             # batch, timesteps, horizontal space, variables
-            x = x[..., None, :]  # add dummy ensemble dimension as 3rd index
-
+            x = x[..., None, :, :]  # add dummy ensemble dimension as 3rd index
             if self.prediction_strategy == "tendency":
                 tendency_hat = self(x)
-                y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...], tendency_hat)
+                y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat)
             else:
                 y_hat = self(x)
                 y_hat = self.post_processors_state(y_hat, in_place=False)
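
After patches 5-7 the residual connection lives in the interface rather than in the encoder-processor-decoder model, and a single training.prediction_strategy keyword selects how the raw network output is interpreted. A condensed, self-contained sketch of that dispatch follows; TinyInterface and LastStep are stand-ins for the real interface and model, and the index list is illustrative.

    import torch

    class LastStep(torch.nn.Module):
        """Stand-in 'model': maps (batch, time, grid, vars) to (batch, grid, vars)."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x[:, -1].clone()

    class TinyInterface(torch.nn.Module):
        def __init__(self, model: torch.nn.Module, prediction_strategy: str, prognostic_idx: list[int]):
            super().__init__()
            self.model = model
            self.prediction_strategy = prediction_strategy
            self.prognostic_idx = prognostic_idx

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x_pred = self.model(x)
            if self.prediction_strategy == "residual":
                # Residual connection on prognostic variables only: the network learns an update.
                x_pred[..., self.prognostic_idx] += x[:, -1, ..., self.prognostic_idx]
            # "tendency": return the normalised increment and rebuild the state later via
            # add_tendency_to_state(); any other value: direct state prediction.
            return x_pred

    iface = TinyInterface(LastStep(), "residual", prognostic_idx=[0, 1])
    out = iface(torch.randn(2, 3, 5, 4))  # (batch, time, grid, vars) -> (batch, grid, vars)
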
From 2aca33041649c5f79e9f690f7fe7e5f6cf86d5eb Mon Sep 17 00:00:00 2001
From: Jakob Schloer
Date: Thu, 31 Oct 2024 10:09:15 +0000
Subject: [PATCH 8/8] Fix bug in inference.

---
 src/anemoi/models/interface/__init__.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py
index 7c366825..7c8be19b 100644
--- a/src/anemoi/models/interface/__init__.py
+++ b/src/anemoi/models/interface/__init__.py
@@ -133,17 +133,19 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
             assert (
                 len(batch.shape) == 4
             ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"
-            x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False)
+            x = self.pre_processors_state(
+                batch[:, 0 : self.multi_step, ...], in_place=False, data_index=self.data_indices.data.input.full
+            )

             # Dimensions are
             # batch, timesteps, horizontal space, variables
-            x = x[..., None, :, :]  # add dummy ensemble dimension as 3rd index
+            x = x[:, :, None, ...]  # add dummy ensemble dimension as 3rd index
             if self.prediction_strategy == "tendency":
                 tendency_hat = self(x)
                 y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat)
             else:
                 y_hat = self(x)
-                y_hat = self.post_processors_state(y_hat, in_place=False)
+                y_hat = self.post_processors_state(y_hat, in_place=False, data_index=self.data_indices.data.output.full)

             return y_hat
@@ -153,9 +155,9 @@ def add_tendency_to_state(self, state_inp: torch.Tensor, tendency: torch.Tensor)

         Parameters
         ----------
         state_inp : torch.Tensor
-            The input state tensor with full input variables and unprocessed.
+            The normalized input state tensor with full input variables.
         tendency : torch.Tensor
-            The tendency tensor output from model.
+            The normalized tendency tensor output from model.

         Returns
         -------
@@ -173,8 +175,10 @@ def add_tendency_to_state(self, state_inp: torch.Tensor, tendency: torch.Tensor)
             data_index=self.data_indices.data.output.diagnostic,
         )

-        state_outp[..., self.data_indices.model.output.prognostic] += state_inp[
-            ..., self.data_indices.model.input.prognostic
-        ]
+        state_outp[..., self.data_indices.model.output.prognostic] += self.post_processors_state(
+            state_inp[..., self.data_indices.model.input.prognostic],
+            in_place=False,
+            data_index=self.data_indices.data.input.prognostic,
+        )

         return state_outp
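
With this final fix, add_tendency_to_state receives the normalised input state (the last pre-processed timestep) and the normalised predicted tendency, and all denormalisation happens inside it: the tendency post-processor inverts the increment, diagnostic outputs are inverted with the state statistics as full fields, and the prognostic part of the input state is also inverted before being added. The same composition expressed with plain mean/std statistics (illustrative names and values, not the anemoi API):

    import torch

    def compose_state(x_norm_last, tendency_norm, state_mean, state_std, tend_mean, tend_std,
                      prognostic_idx, diagnostic_idx):
        # Inverse of the tendency normaliser, applied to every output variable.
        state_out = tendency_norm * tend_std + tend_mean
        # Diagnostic variables are full fields, so invert them with the state statistics.
        state_out[..., diagnostic_idx] = (
            tendency_norm[..., diagnostic_idx] * state_std[diagnostic_idx] + state_mean[diagnostic_idx]
        )
        # Prognostic variables: denormalised previous state plus denormalised increment.
        state_out[..., prognostic_idx] += (
            x_norm_last[..., prognostic_idx] * state_std[prognostic_idx] + state_mean[prognostic_idx]
        )
        return state_out

    state_mean = torch.tensor([280.0, 0.0, 101000.0, 0.0])
    state_std = torch.tensor([15.0, 10.0, 1000.0, 1.0])
    tend_mean = torch.zeros(4)
    tend_std = torch.tensor([1.0, 2.0, 50.0, 1.0])
    y = compose_state(torch.randn(2, 10, 4), torch.randn(2, 10, 4),
                      state_mean, state_std, tend_mean, tend_std, [0, 1, 2], [3])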