Refactor node attributes (ecmwf#64)
* feat: add NamedNodeAttributes

* feat: use NamedNodesAttributes in AnemoiModelEncProcDec

* fix: update changelog

* fix: add tests

* feat: drop unused attrs + type hints

* fix: homogenise

* fix: update docstring

* fix: more docstrings
JPXKQX authored Nov 11, 2024
1 parent e9c8172 commit 8874571
Showing 4 changed files with 165 additions and 52 deletions.
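At a glance, the refactor replaces the per-grid lat/lon buffers and trainable tensors that AnemoiModelEncProcDec managed by hand with a single NamedNodesAttributes module. A minimal sketch of the new call pattern follows (the two-node-set toy graph and sizes are hypothetical; the import path is the one used by the new tests):

import torch
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes

# Toy graph: two named node sets, each storing (lat, lon) coordinates in `x`.
graph = HeteroData()
graph["data"].x = torch.rand(40, 2)
graph["hidden"].x = torch.rand(10, 2)

node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)

# One call per named node set: sin/cos-encoded coordinates concatenated
# with the repeated trainable parameters.
x_hidden = node_attributes("hidden", batch_size=2)
assert x_hidden.shape == (2 * 10, 2 * 2 + 8)  # (batch * nodes, attr_ndims["hidden"])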
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

### Changed
72 changes: 71 additions & 1 deletion src/anemoi/models/layers/graph.py
@@ -12,6 +12,7 @@
import torch
from torch import Tensor
from torch import nn
from torch_geometric.data import HeteroData


class TrainableTensor(nn.Module):
@@ -36,8 +37,77 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
    def forward(self, x: Tensor, batch_size: int) -> Tensor:
        latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
        if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
        return torch.cat(
            latent,
            dim=-1,  # feature dimension
        )


class NamedNodesAttributes(nn.Module):
    """Named Nodes Attributes information.

    Attributes
    ----------
    num_nodes : dict[str, int]
        Number of nodes for each group of nodes.
    attr_ndims : dict[str, int]
        Total dimension of node attributes (non-trainable + trainable) for each group of nodes.
    trainable_tensors : nn.ModuleDict
        Dictionary of trainable tensors for each group of nodes.

    Methods
    -------
    get_coordinates(self, name: str) -> Tensor
        Get the coordinates of a set of nodes.
    forward(self, name: str, batch_size: int) -> Tensor
        Get the node attributes to be passed through the graph neural network.
    """

    num_nodes: dict[str, int]
    attr_ndims: dict[str, int]
    trainable_tensors: dict[str, TrainableTensor]

    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
        """Initialize NamedNodesAttributes."""
        super().__init__()

        self.define_fixed_attributes(graph_data, num_trainable_params)

        self.trainable_tensors = nn.ModuleDict()
        for nodes_name, nodes in graph_data.node_items():
            self.register_coordinates(nodes_name, nodes.x)
            self.register_tensor(nodes_name, num_trainable_params)

    def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None:
        """Define fixed attributes."""
        nodes_names = list(graph_data.node_types)
        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in nodes_names}
        self.attr_ndims = {
            nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in nodes_names
        }

    def register_coordinates(self, name: str, node_coords: Tensor) -> None:
        """Register coordinates."""
        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

    def get_coordinates(self, name: str) -> Tensor:
        """Return original coordinates."""
        sin_cos_coords = getattr(self, f"latlons_{name}")
        ndim = sin_cos_coords.shape[1] // 2
        sin_values = sin_cos_coords[:, :ndim]
        cos_values = sin_cos_coords[:, ndim:]
        return torch.atan2(sin_values, cos_values)

    def register_tensor(self, name: str, num_trainable_params: int) -> None:
        """Register a trainable tensor."""
        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params)

    def forward(self, name: str, batch_size: int) -> Tensor:
        """Return the node attributes to be passed through the graph neural network.

        It includes both the coordinates and the trainable parameters.
        """
        latlons = getattr(self, f"latlons_{name}")
        return self.trainable_tensors[name](latlons, batch_size)
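A note on get_coordinates: the sin/cos encoding registered above is invertible because atan2(sin θ, cos θ) returns θ mapped into (-π, π]. A small self-contained check of that roundtrip (the coordinate ranges mirror those generated in the tests below):

import math

import torch

coords = torch.rand(5, 2)
coords[:, 0] = math.pi * (coords[:, 0] - 0.5)  # latitudes in (-pi/2, pi/2)
coords[:, 1] = 2 * math.pi * coords[:, 1]      # longitudes in [0, 2*pi)

# What register_coordinates stores:
sin_cos = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

# What get_coordinates computes:
ndim = sin_cos.shape[1] // 2
recovered = torch.atan2(sin_cos[:, :ndim], sin_cos[:, ndim:])

# atan2 maps angles into (-pi, pi], so longitudes above pi come back shifted
# by -2*pi; they denote the same angle modulo 2*pi.
assert torch.allclose(torch.sin(recovered), torch.sin(coords), atol=1e-6)
assert torch.allclose(torch.cos(recovered), torch.cos(coords), atol=1e-6)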
65 changes: 14 additions & 51 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -22,7 +22,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.graph import NamedNodesAttributes

LOGGER = logging.getLogger(__name__)

@@ -56,42 +56,33 @@ def __init__(

        self._calculate_shapes_and_indices(data_indices)
        self._assert_matching_indices(data_indices)

-        self.multi_step = model_config.training.multistep_input

-        self._define_tensor_sizes(model_config)

-        # Create trainable tensors
-        self._create_trainable_attributes()

-        # Register lat/lon of nodes
-        self._register_latlon("data", self._graph_name_data)
-        self._register_latlon("hidden", self._graph_name_hidden)

        self.data_indices = data_indices

+        self.multi_step = model_config.training.multistep_input
        self.num_channels = model_config.model.num_channels

-        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)

+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

        # Encoder data -> hidden
        self.encoder = instantiate(
            model_config.model.encoder,
            in_channels_src=input_dim,
-            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
            hidden_dim=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
-            src_grid_size=self._data_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        # Processor hidden -> hidden
        self.processor = instantiate(
            model_config.model.processor,
            num_channels=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        # Decoder hidden -> data
@@ -102,8 +93,8 @@ def __init__(
            hidden_dim=self.num_channels,
            out_channels_dst=self.num_output_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
        )

        # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -133,34 +124,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
            self._internal_output_idx,
        ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )

    def _run_mapper(
        self,
        mapper: nn.Module,
@@ -210,12 +173,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
        x_data_latent = torch.cat(
            (
                einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
            ),
            dim=-1,  # feature dimension
        )

-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

        # get shard shapes
        shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
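The net effect in this file is that every size the mappers need becomes a lookup on node_attributes. A sketch of the arithmetic with toy numbers (multi_step and num_input_channels stand in for config values; note the refactored model passes model_config.model.trainable_parameters.hidden as the trainable size for all node sets):

coord_dim = 2               # each node stores (lat, lon)
num_trainable_params = 8    # model_config.model.trainable_parameters.hidden

# attr_ndims[name] = 2 * coord_dim + num_trainable_params  (sin/cos + trainable)
attr_ndims_data = 2 * coord_dim + num_trainable_params    # 12
attr_ndims_hidden = 2 * coord_dim + num_trainable_params  # 12

multi_step = 2              # model_config.training.multistep_input
num_input_channels = 10

# Encoder source width: flattened (time, vars) window plus data-node attributes.
input_dim = multi_step * num_input_channels + attr_ndims_data  # 2 * 10 + 12 = 32

# Encoder destination width: just the hidden nodes' attribute dimension.
in_channels_dst = attr_ndims_hidden  # 12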
79 changes: 79 additions & 0 deletions tests/layers/test_graph.py
@@ -8,10 +8,14 @@
# nor does it submit to any jurisdiction.


import einops
import numpy as np
import pytest
import torch
from torch import nn
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.graph import TrainableTensor


@@ -62,3 +66,78 @@ def test_forward_no_trainable(self, init, x):
        batch_size = 5
        output = trainable_tensor(x, batch_size)
        assert output.shape == (batch_size * x.shape[0], tensor_size + trainable_size)


class TestNamedNodesAttributes:
    """Test suite for the NamedNodesAttributes class.

    This class contains test cases to verify the functionality of the NamedNodesAttributes class,
    including initialization, attribute registration, and forward pass operations.
    """

    nodes_names: list[str] = ["nodes1", "nodes2"]
    ndim: int = 2
    num_trainable_params: int = 8

    @pytest.fixture
    def graph_data(self):
        graph = HeteroData()
        for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names):
            graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i + 1))
        return graph

    @staticmethod
    def get_n_random_coords(n: int) -> torch.Tensor:
        coords = torch.rand(n, TestNamedNodesAttributes.ndim)
        coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2)
        coords[:, 1] = 2 * np.pi * coords[:, 1]
        return coords

    @pytest.fixture
    def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes:
        return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data)

    def test_init(self, nodes_attributes):
        assert isinstance(nodes_attributes, NamedNodesAttributes)

        for nodes_name in self.nodes_names:
            assert isinstance(nodes_attributes.num_nodes[nodes_name], int)
            assert (
                nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim
                == TestNamedNodesAttributes.num_trainable_params
            )
            assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor)

    def test_forward(self, nodes_attributes, graph_data):
        batch_size = 3
        for nodes_name in self.nodes_names:
            output = nodes_attributes(nodes_name, batch_size)

            expected_shape = (
                batch_size * graph_data[nodes_name].num_nodes,
                2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params,
            )
            assert output.shape == expected_shape

            # Check if the first part of the output matches the sin-cos transformed coordinates
            latlons = getattr(nodes_attributes, f"latlons_{nodes_name}")
            repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
            assert torch.allclose(output[:, : 2 * TestNamedNodesAttributes.ndim], repeated_latlons)

            # Check if the last part of the output is trainable (requires grad)
            assert output[:, 2 * TestNamedNodesAttributes.ndim :].requires_grad

    def test_forward_no_trainable(self, graph_data):
        no_trainable_attributes = NamedNodesAttributes(0, graph_data)
        batch_size = 2

        for nodes_name in self.nodes_names:
            output = no_trainable_attributes(nodes_name, batch_size)

            expected_shape = batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim
            assert output.shape == expected_shape

            # Check if the output exactly matches the sin-cos transformed coordinates
            latlons = getattr(no_trainable_attributes, f"latlons_{nodes_name}")
            repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
            assert torch.allclose(output, repeated_latlons)
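The requires_grad checks above are the point of the trainable slice: gradients reach the per-node parameters while the coordinate slice stays a fixed buffer. A minimal sketch of one training step, assuming (as TrainableTensor.forward above suggests) that the parameters live in a .trainable attribute:

import torch
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes

graph = HeteroData()
graph["nodes1"].x = torch.rand(15, 2)

module = NamedNodesAttributes(4, graph)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

out = module("nodes1", batch_size=2)  # shape (2 * 15, 2 * 2 + 4)
loss = out[:, 4:].pow(2).sum()        # dummy loss on the trainable slice only
loss.backward()

trainable = module.trainable_tensors["nodes1"].trainable
assert trainable.grad is not None  # gradients reached the node parameters
optimizer.step()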
