Skip to content

Commit

Permalink
ZipProsessors for post and pre proc tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Nov 22, 2024
1 parent fd2bcf1 commit f391a03
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.preprocessing import Processors
from anemoi.models.preprocessing import ZipProcessors


class AnemoiModelInterface(torch.nn.Module):
Expand Down Expand Up @@ -111,3 +112,43 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
y_hat = self(x)

return self.post_processors(y_hat, in_place=False)

class FuserModelInterface(torch.nn.Module):

def __init__(
self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
) -> None:
super().__init__()
self.config = config
self.id = str(uuid.uuid4())
self.multi_step = self.config.training.multistep_input
self.graph_data = graph_data
self.statistics = statistics
self.metadata = metadata
self.data_indices = data_indices
self._build_model()

def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Instantiate processors
processors = tuple([
[name, instantiate(processor, data_indices=self.data_indices[i], statistics=self.statistics[i])]
for name, processor in dset_config.processors.items()
] for i, dset_config in enumerate(self.config.data.zip))

# Assign the processor list pre- and post-processors
self.pre_processors = ZipProcessors(processors)
self.post_processors = ZipProcessors(processors, inverse=True)
'''
# Instantiate the model
self.model = instantiate(
self.config.model.model,
model_config=self.config,
data_indices=self.data_indices,
graph_data=self.graph_data,
_recursive_=False, # Disables recursive instantiation by Hydra
)
# Use the forward method of the model directly
self.forward = self.model.forward
'''
24 changes: 24 additions & 0 deletions src/anemoi/models/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,27 @@ def _run_checks(self, x):
assert not torch.isnan(
x
).any(), f"NaNs ({torch.isnan(x).sum()}) found in processed tensor after {self.__class__.__name__}."

class ZipProcessors(nn.Module):

def __init__(self, processors_zip: tuple, inverse: bool = False) -> None:

super().__init__()
self.inverse = inverse
processors = [Processors(processors, inverse=inverse) for processors in processors_zip]
self.processors = nn.ModuleList(processors)

def __repr__(self) -> str:
return f"{self.__class__.__name__} [{'inverse' if self.inverse else 'forward'}]({self.processors})"

def forward(self, x: tuple, in_place: bool = True) -> tuple:
if not in_place:
y=()
for i, processor in enumerate(self.processors):
y += (processor(x[i], in_place=False),)
return y
else:
for i, processor in enumerate(self.processors):
x[i] = (processor(x[i], in_place=True),)
return x

0 comments on commit f391a03

Please sign in to comment.