diff --git a/.gitignore b/.gitignore index d8baf06..d92403b 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.envrc .venv env/ venv/ diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 0967041..442676d 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -39,12 +39,13 @@ class BaseMapper(nn.Module, ABC): def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + activation: str, cpu_offload: bool = False, - activation: str = "SiLU", **kwargs, ) -> None: """Initialize BaseMapper.""" @@ -175,20 +176,21 @@ class GraphTransformerBaseMapper(GraphEdgeMixin, BaseMapper): def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, + num_heads: int, + mlp_hidden_ratio: int, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, cpu_offload: bool = False, - activation: str = "GELU", - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, ) -> None: """Initialize GraphTransformerBaseMapper. @@ -200,23 +202,33 @@ def __init__( Input channels of the destination node hidden_dim : int Hidden dimension + out_channels_dst : int + Output channels of the destination node trainable_size : int Trainable tensor of edge + num_chunks: int, optional + Message passing in chunks, by default 1 num_heads: int - Number of heads to use, default 16 + Number of heads to use mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - activation : str, optional - Activation function, by default "GELU" + Ratio of mlp hidden dimension to embedding dimension + activation : str + Activation function + sub_graph : HeteroData + Sub graph of the full structure + sub_graph_edge_attributes : list[str] + Edge attributes to use + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False - out_channels_dst : Optional[int], optional - Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, out_channels_dst=out_channels_dst, num_chunks=num_chunks, cpu_offload=cpu_offload, @@ -228,9 +240,9 @@ def __init__( self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) self.proc = GraphTransformerMapperBlock( - hidden_dim, - mlp_hidden_ratio * hidden_dim, - hidden_dim, + in_channels=hidden_dim, + hidden_dim=mlp_hidden_ratio * hidden_dim, + out_channels=hidden_dim, num_heads=num_heads, edge_dim=self.edge_dim, activation=activation, @@ -276,20 +288,21 @@ class GraphTransformerForwardMapper(ForwardMapperPreProcessMixin, GraphTransform def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 
128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, + num_heads: int, + mlp_hidden_ratio: int, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, cpu_offload: bool = False, - activation: str = "GELU", - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, ) -> None: """Initialize GraphTransformerForwardMapper. @@ -301,24 +314,34 @@ def __init__( Input channels of the destination node hidden_dim : int Hidden dimension + out_channels_dst : int + Output channels of the destination node trainable_size : int Trainable tensor of edge + num_chunks: int, optional + Message passing in chunks, by default 1 num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int ratio of mlp hidden dimension to embedding dimension, default 4 activation : str, optional Activation function, by default "GELU" + sub_graph : HeteroData + Sub graph passed in by the model + sub_graph_edge_attributes : list[str] + Edge attributes to use + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False - out_channels_dst : Optional[int], optional - Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, - trainable_size, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, + trainable_size=trainable_size, out_channels_dst=out_channels_dst, num_chunks=num_chunks, cpu_offload=cpu_offload, @@ -349,20 +372,21 @@ class GraphTransformerBackwardMapper(BackwardMapperPostProcessMixin, GraphTransf def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, + num_heads: int, + mlp_hidden_ratio: int, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, cpu_offload: bool = False, - activation: str = "GELU", - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, ) -> None: """Initialize GraphTransformerBackwardMapper. 
@@ -374,24 +398,34 @@ def __init__( Input channels of the destination node hidden_dim : int Hidden dimension + out_channels_dst : int + Output channels of the destination node trainable_size : int Trainable tensor of edge + num_chunks: int, optional + Message passing in chunks, by default 1 num_heads: int - Number of heads to use, default 16 + Number of heads to use mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - activation : str, optional - Activation function, by default "GELU" + ratio of mlp hidden dimension to embedding dimension + activation : str + Activation function + sub_graph : HeteroData + Sub graph passed in by the model + sub_graph_edge_attributes : list[str] + Edge attributes to use + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False - out_channels_dst : Optional[int], optional - Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, - trainable_size, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, + trainable_size=trainable_size, out_channels_dst=out_channels_dst, num_chunks=num_chunks, cpu_offload=cpu_offload, @@ -422,19 +456,20 @@ class GNNBaseMapper(GraphEdgeMixin, BaseMapper): def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, - cpu_offload: bool = False, - activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, + cpu_offload: bool = False, ) -> None: """Initialize GNNBaseMapper. 
@@ -446,23 +481,31 @@ def __init__( Input channels of the destination node hidden_dim : int Hidden dimension + out_channels_dst : int + Output channels of the destination node trainable_size : int Trainable tensor of edge + num_chunks: int, optional + Message passing in chunks, by default 1 mlp_extra_layers : int, optional Number of extra layers in MLP, by default 0 activation : str, optional Activation function, by default "SiLU" - num_chunks : int - Do message passing in X chunks + sub_graph: HeteroData + Sub graph passed in by the model + sub_graph_edge_attributes : list[str] + Edge attributes to use + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False - out_channels_dst : Optional[int], optional - Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, out_channels_dst=out_channels_dst, num_chunks=num_chunks, cpu_offload=cpu_offload, @@ -526,19 +569,20 @@ class GNNForwardMapper(ForwardMapperPreProcessMixin, GNNBaseMapper): def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, - cpu_offload: bool = False, - activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, + cpu_offload: bool = False, ) -> None: """Initialize GNNForwardMapper. 
@@ -550,38 +594,46 @@ def __init__( Input channels of the destination node hidden_dim : int Hidden dimension - edge_dim : int + out_channels_dst : int + Output channels of the destination node + trainable_size : int Trainable tensor of edge - mlp_extra_layers : int, optional + num_chunks : int, optional + Do message passing in chunks, by default 1 + mlp_extra_layers : int Number of extra layers in MLP, by default 0 - activation : str, optional - Activation function, by default "SiLU" - num_chunks : int - Do message passing in X chunks + activation : str + Activation function + sub_graph : HeteroData + Sub graph passed in by the model + sub_graph_edge_attributes : list[str] + Edge attributes to use + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False - out_channels_dst : Optional[int], optional - Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, - trainable_size, - out_channels_dst, - num_chunks, - cpu_offload, - activation, - mlp_extra_layers, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, + out_channels_dst=out_channels_dst, + trainable_size=trainable_size, + num_chunks=num_chunks, + mlp_extra_layers=mlp_extra_layers, + activation=activation, sub_graph=sub_graph, sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + cpu_offload=cpu_offload, ) self.proc = GraphConvMapperBlock( - hidden_dim, - hidden_dim, + in_channels=hidden_dim, + out_channels=hidden_dim, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=True, @@ -612,19 +664,20 @@ class GNNBackwardMapper(BackwardMapperPostProcessMixin, GNNBaseMapper): def __init__( self, - in_channels_src: int = 0, - in_channels_dst: int = 0, - hidden_dim: int = 128, - trainable_size: int = 8, - out_channels_dst: Optional[int] = None, + *, + in_channels_src: int, + in_channels_dst: int, + hidden_dim: int, + out_channels_dst: int, + trainable_size: int = 0, num_chunks: int = 1, - cpu_offload: bool = False, - activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, + cpu_offload: bool = False, ) -> None: """Initialize GNNBackwardMapper. 
@@ -650,10 +703,10 @@ def __init__( Output channels of the destination node, by default None """ super().__init__( - in_channels_src, - in_channels_dst, - hidden_dim, - trainable_size, + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, + trainable_size=trainable_size, out_channels_dst=out_channels_dst, num_chunks=num_chunks, cpu_offload=cpu_offload, @@ -666,8 +719,8 @@ def __init__( ) self.proc = GraphConvMapperBlock( - hidden_dim, - hidden_dim, + in_channels=hidden_dim, + out_channels=hidden_dim, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=False, diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index bb33609..cc36845 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -35,9 +35,9 @@ def __init__( self, num_layers: int, *args, - num_channels: int = 128, - num_chunks: int = 2, - activation: str = "GELU", + num_channels: int, + num_chunks: int, + activation: str, cpu_offload: bool = False, **kwargs, ) -> None: @@ -86,15 +86,15 @@ class TransformerProcessor(BaseProcessor): def __init__( self, - num_layers: int, *args, - window_size: Optional[int] = None, - num_channels: int = 128, - num_chunks: int = 2, - activation: str = "GELU", + num_layers: int, + num_channels: int, + num_chunks: int, + num_heads: int, + mlp_hidden_ratio: int, + activation: str, cpu_offload: bool = False, - num_heads: int = 16, - mlp_hidden_ratio: int = 4, + window_size: Optional[int] = None, **kwargs, ) -> None: """Initialize TransformerProcessor. @@ -103,16 +103,20 @@ def __init__( ---------- num_layers : int Number of num_layers - window_size: int, - 1/2 size of shifted window for attention computation num_channels : int number of channels - heads: int - Number of heads to use, default 16 + num_chunks : int + Number of chunks + num_heads: int + Number of heads to use mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - activation : str, optional - Activation function, by default "GELU" + ratio of mlp hidden dimension to embedding dimension + activation : str + Activation function + cpu_offload : bool + Whether to offload processing to CPU + window_size: int, optional + 1/2 size of shifted window for attention computation """ super().__init__( num_channels=num_channels, @@ -162,34 +166,44 @@ class GNNProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, - num_layers: int, *args, - trainable_size: int = 8, - num_channels: int = 128, - num_chunks: int = 2, + num_layers: int, + num_channels: int, + num_chunks: int, + trainable_size: int, mlp_extra_layers: int = 0, - activation: str = "SiLU", + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, cpu_offload: bool = False, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, **kwargs, ) -> None: """Initialize GNNProcessor. 
Parameters ---------- - num_channels : int - Number of Channels num_layers : int Number of layers - num_chunks : int, optional - Number of num_chunks, by default 2 + num_channels : int + Number of Channels + num_chunks : int + Number of num_chunks + trainable_size : int + Size of trainable tensor mlp_extra_layers : int, optional Number of extra layers in MLP, by default 0 - activation : str, optional - Activation function, by default "SiLU" + activation : str + Activation function + sub_graph : HeteroData + Sub graph passed in from model + sub_graph_edge_attributes : list[str] + List of edge attributes + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False """ @@ -250,18 +264,19 @@ class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, + *args, num_layers: int, - trainable_size: int = 8, - num_channels: int = 128, - num_chunks: int = 2, - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - activation: str = "GELU", + num_channels: int, + num_chunks: int, + num_heads: int, + trainable_size: int, + mlp_hidden_ratio: int, + activation: str, + sub_graph: HeteroData, + sub_graph_edge_attributes: list[str], + src_grid_size: int, + dst_grid_size: int, cpu_offload: bool = False, - sub_graph: Optional[HeteroData] = None, - sub_graph_edge_attributes: Optional[list[str]] = None, - src_grid_size: int = 0, - dst_grid_size: int = 0, **kwargs, ) -> None: """Initialize GraphTransformerProcessor. @@ -273,13 +288,23 @@ def __init__( num_channels : int Number of channels num_chunks : int, optional - Number of num_chunks, by default 2 - heads: int - Number of heads to use, default 16 + Number of num_chunks + num_heads: int + Number of heads to use + trainable_size : int + Size of trainable tensor mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - activation : str, optional - Activation function, by default "GELU" + ratio of mlp hidden dimension to embedding dimension + activation : str + Activation function + sub_graph : HeteroData + Sub graph passed in from model + sub_graph_edge_attributes : list[str] + List of edge attributes + src_grid_size : int + Source grid size + dst_grid_size : int + Destination grid size cpu_offload : bool, optional Whether to offload processing to CPU, by default False """ diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py index 3cc4ef0..2060182 100644 --- a/tests/layers/mapper/test_base_mapper.py +++ b/tests/layers/mapper/test_base_mapper.py @@ -5,6 +5,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+from dataclasses import dataclass + import pytest import torch from torch_geometric.data import HeteroData @@ -12,6 +14,19 @@ from anemoi.models.layers.mapper import BaseMapper +@dataclass +class BaseMapperConfig: + sub_graph: HeteroData + sub_graph_edge_attributes: list[str] + in_channels_src: int = 3 + in_channels_dst: int = 3 + hidden_dim: int = 128 + out_channels_dst: int = 5 + cpu_offload: bool = False + activation: str = "SiLU" + trainable_size: int = 6 + + class TestBaseMapper: """Test the BaseMapper class.""" @@ -20,61 +35,31 @@ class TestBaseMapper: NUM_DST_NODES: int = 200 @pytest.fixture - def mapper_init(self): - in_channels_src: int = 3 - in_channels_dst: int = 3 - hidden_dim: int = 128 - out_channels_dst: int = 5 - cpu_offload: bool = False - activation: str = "SiLU" - trainable_size: int = 6 - return ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, + def mapper_init(self, fake_graph): + return BaseMapperConfig( + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], ) @pytest.fixture - def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - ) = mapper_init + def mapper(self, mapper_init): return BaseMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, - sub_graph=fake_graph[("src", "to", "dst")], - sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, + sub_graph=mapper_init.sub_graph, + sub_graph_edge_attributes=mapper_init.sub_graph_edge_attributes, + trainable_size=mapper_init.trainable_size, ) @pytest.fixture def pair_tensor(self, mapper_init): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init return ( - torch.rand(in_channels_src, hidden_dim), - torch.rand(in_channels_dst, hidden_dim), + torch.rand(mapper_init.in_channels_src, mapper_init.hidden_dim), + torch.rand(mapper_init.in_channels_dst, mapper_init.hidden_dim), ) @pytest.fixture @@ -93,21 +78,12 @@ def fake_graph(self) -> HeteroData: return graph def test_initialization(self, mapper, mapper_init): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - activation, - _trainable_size, - ) = mapper_init assert isinstance(mapper, BaseMapper) - assert mapper.in_channels_src == in_channels_src - assert mapper.in_channels_dst == in_channels_dst - assert mapper.hidden_dim == hidden_dim - assert mapper.out_channels_dst == out_channels_dst - assert mapper.activation == activation + assert mapper.in_channels_src == mapper_init.in_channels_src + assert mapper.in_channels_dst == mapper_init.in_channels_dst + assert mapper.hidden_dim == mapper_init.hidden_dim + assert mapper.out_channels_dst == mapper_init.out_channels_dst + assert mapper.activation == mapper_init.activation def test_pre_process(self, mapper, pair_tensor): x = pair_tensor diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 
4be7130..c8df6c5 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -5,6 +5,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from dataclasses import dataclass + import pytest import torch from torch import nn @@ -15,6 +17,19 @@ from anemoi.models.layers.mapper import GNNForwardMapper +@dataclass +class MapperConfig: + dst_grid_size: int + src_grid_size: int + in_channels_src: int = 3 + in_channels_dst: int = 3 + hidden_dim: int = 256 + out_channels_dst: int = 5 + cpu_offload: bool = False + activation: str = "SiLU" + trainable_size: int = 6 + + class TestGNNBaseMapper: """Test the GNNBaseMapper class.""" @@ -24,60 +39,29 @@ class TestGNNBaseMapper: @pytest.fixture def mapper_init(self): - in_channels_src: int = 3 - in_channels_dst: int = 4 - hidden_dim: int = 256 - out_channels_dst: int = 8 - cpu_offload: bool = False - activation: str = "SiLU" - trainable_size: int = 6 - return ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - ) + return MapperConfig(src_grid_size=self.NUM_SRC_NODES, dst_grid_size=self.NUM_DST_NODES) @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - ) = mapper_init return GNNBaseMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, + trainable_size=mapper_init.trainable_size, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) @pytest.fixture def pair_tensor(self, mapper_init): - ( - in_channels_src, - in_channels_dst, - _hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init return ( - torch.rand(self.NUM_SRC_NODES, in_channels_src), - torch.rand(self.NUM_DST_NODES, in_channels_dst), + torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src), + torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst), ) @pytest.fixture @@ -96,34 +80,16 @@ def fake_graph(self) -> HeteroData: return graph def test_initialization(self, mapper, mapper_init): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - activation, - _trainable_size, - ) = mapper_init assert isinstance(mapper, GNNBaseMapper) - assert mapper.in_channels_src == in_channels_src - assert mapper.in_channels_dst == in_channels_dst - assert mapper.hidden_dim == hidden_dim - assert mapper.out_channels_dst == out_channels_dst - assert mapper.activation == activation + assert mapper.in_channels_src == mapper_init.in_channels_src + assert mapper.in_channels_dst == mapper_init.in_channels_dst + assert mapper.hidden_dim == mapper_init.hidden_dim + assert mapper.out_channels_dst == mapper_init.out_channels_dst + assert mapper.activation == mapper_init.activation - def test_pre_process(self, mapper, mapper_init, pair_tensor): + def 
test_pre_process(self, mapper, pair_tensor): # Should be a no-op in the base class x = pair_tensor - ( - _in_channels_src, - _in_channels_dst, - _hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -154,38 +120,23 @@ class TestGNNForwardMapper(TestGNNBaseMapper): @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - ) = mapper_init return GNNForwardMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, + trainable_size=mapper_init.trainable_size, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): x = pair_tensor - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -201,17 +152,9 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_forward_backward(self, mapper_init, mapper, pair_tensor): - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init x = pair_tensor batch_size = 1 + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst = mapper.forward(x, batch_size, shard_shapes) @@ -245,38 +188,25 @@ class TestGNNBackwardMapper(TestGNNBaseMapper): @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - ) = mapper_init return GNNBackwardMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, + trainable_size=mapper_init.trainable_size, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): x = pair_tensor - ( - in_channels_src, - in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init + in_channels_src = 
mapper_init.in_channels_src + in_channels_dst = mapper_init.in_channels_dst + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -292,15 +222,8 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_post_process(self, mapper, mapper_init): - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init + hidden_dim = mapper_init.hidden_dim + out_channels_dst = mapper_init.out_channels_dst x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] @@ -310,15 +233,8 @@ def test_post_process(self, mapper, mapper_init): ), f"[self.NUM_DST_NODES, out_channels_dst] ({[self.NUM_DST_NODES, out_channels_dst]}) != result.shape ({result.shape})" def test_forward_backward(self, mapper_init, mapper, pair_tensor): - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - ) = mapper_init + hidden_dim = mapper_init.hidden_dim + out_channels_dst = mapper_init.out_channels_dst pair_tensor shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)] batch_size = 1 diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index c872422..84473bc 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -5,6 +5,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from dataclasses import dataclass + import pytest import torch from torch import nn @@ -15,6 +17,21 @@ from anemoi.models.layers.mapper import GraphTransformerForwardMapper +@dataclass +class MapperConfig: + in_channels_src: int = 3 + in_channels_dst: int = 3 + hidden_dim: int = 256 + out_channels_dst: int = 5 + cpu_offload: bool = False + activation: str = "SiLU" + trainable_size: int = 6 + num_heads: int = 16 + mlp_hidden_ratio: int = 7 + src_grid_size: int = 3 + dst_grid_size: int = 9 + + class TestGraphTransformerBaseMapper: """Test the GraphTransformerBaseMapper class.""" @@ -24,70 +41,31 @@ class TestGraphTransformerBaseMapper: @pytest.fixture def mapper_init(self): - in_channels_src: int = 3 - in_channels_dst: int = 3 - hidden_dim: int = 256 - out_channels_dst: int = 5 - cpu_offload: bool = False - activation: str = "SiLU" - trainable_size: int = 6 - num_heads: int = 16 - mlp_hidden_ratio: int = 7 - return ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - num_heads, - mlp_hidden_ratio, - ) + return MapperConfig(src_grid_size=self.NUM_SRC_NODES, dst_grid_size=self.NUM_DST_NODES) @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - num_heads, - mlp_hidden_ratio, - ) = mapper_init return GraphTransformerBaseMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + 
cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, + trainable_size=mapper_init.trainable_size, + num_heads=mapper_init.num_heads, + mlp_hidden_ratio=mapper_init.mlp_hidden_ratio, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) @pytest.fixture def pair_tensor(self, mapper_init): - ( - in_channels_src, - in_channels_dst, - _hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init return ( - torch.rand(self.NUM_SRC_NODES, in_channels_src), - torch.rand(self.NUM_DST_NODES, in_channels_dst), + torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src), + torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst), ) @pytest.fixture @@ -106,38 +84,17 @@ def fake_graph(self) -> HeteroData: return graph def test_initialization(self, mapper, mapper_init): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init assert isinstance(mapper, GraphTransformerBaseMapper) - assert mapper.in_channels_src == in_channels_src - assert mapper.in_channels_dst == in_channels_dst - assert mapper.hidden_dim == hidden_dim - assert mapper.out_channels_dst == out_channels_dst - assert mapper.activation == activation + assert mapper.in_channels_src == mapper_init.in_channels_src + assert mapper.in_channels_dst == mapper_init.in_channels_dst + assert mapper.hidden_dim == mapper_init.hidden_dim + assert mapper.out_channels_dst == mapper_init.out_channels_dst + assert mapper.activation == mapper_init.activation def test_pre_process(self, mapper, mapper_init, pair_tensor): + del mapper_init # Should be a no-op in the base class x = pair_tensor - ( - _in_channels_src, - _in_channels_dst, - _hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -168,44 +125,25 @@ class TestGraphTransformerForwardMapper(TestGraphTransformerBaseMapper): @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - num_heads, - mlp_hidden_ratio, - ) = mapper_init return GraphTransformerForwardMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, + trainable_size=mapper_init.trainable_size, + num_heads=mapper_init.num_heads, + mlp_hidden_ratio=mapper_init.mlp_hidden_ratio, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) def test_pre_process(self, 
mapper, mapper_init, pair_tensor): x = pair_tensor - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -221,19 +159,11 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_forward_backward(self, mapper_init, mapper, pair_tensor): - ( - in_channels_src, - _in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init + x = pair_tensor batch_size = 1 + in_channels_src = mapper_init.in_channels_src + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst = mapper.forward(x, batch_size, shard_shapes) @@ -267,42 +197,26 @@ class TestGraphTransformerBackwardMapper(TestGraphTransformerBaseMapper): @pytest.fixture def mapper(self, mapper_init, fake_graph): - ( - in_channels_src, - in_channels_dst, - hidden_dim, - out_channels_dst, - cpu_offload, - activation, - trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init return GraphTransformerBackwardMapper( - in_channels_src=in_channels_src, - in_channels_dst=in_channels_dst, - hidden_dim=hidden_dim, - out_channels_dst=out_channels_dst, - cpu_offload=cpu_offload, - activation=activation, + in_channels_src=mapper_init.in_channels_src, + in_channels_dst=mapper_init.in_channels_dst, + hidden_dim=mapper_init.hidden_dim, + out_channels_dst=mapper_init.out_channels_dst, + cpu_offload=mapper_init.cpu_offload, + activation=mapper_init.activation, sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], - trainable_size=trainable_size, + trainable_size=mapper_init.trainable_size, + num_heads=mapper_init.num_heads, + mlp_hidden_ratio=mapper_init.mlp_hidden_ratio, + src_grid_size=mapper_init.src_grid_size, + dst_grid_size=mapper_init.dst_grid_size, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): x = pair_tensor - ( - in_channels_src, - _in_channels_dst, - hidden_dim, - _out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init + in_channels_src = mapper_init.in_channels_src + hidden_dim = mapper_init.hidden_dim shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) @@ -318,17 +232,8 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_post_process(self, mapper, mapper_init): - ( - _in_channels_src, - _in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - _activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init + hidden_dim = mapper_init.hidden_dim + out_channels_dst = mapper_init.out_channels_dst x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] @@ -338,18 +243,10 @@ def test_post_process(self, mapper, mapper_init): ), f"[self.NUM_DST_NODES, out_channels_dst] ({[self.NUM_DST_NODES, out_channels_dst]}) != result.shape ({result.shape})" def test_forward_backward(self, mapper_init, mapper, pair_tensor): - ( - in_channels_src, - _in_channels_dst, - hidden_dim, - out_channels_dst, - _cpu_offload, - 
_activation, - _trainable_size, - _num_heads, - _mlp_hidden_ratio, - ) = mapper_init - pair_tensor + in_channels_src = mapper_init.in_channels_src + hidden_dim = mapper_init.hidden_dim + out_channels_dst = mapper_init.out_channels_dst + shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)] batch_size = 1 diff --git a/tests/layers/processor/test_base_processor.py b/tests/layers/processor/test_base_processor.py index 4af3c7b..1699743 100644 --- a/tests/layers/processor/test_base_processor.py +++ b/tests/layers/processor/test_base_processor.py @@ -5,35 +5,44 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from dataclasses import dataclass + import pytest from anemoi.models.layers.processor import BaseProcessor +@dataclass +class ProcessorInit: + num_layers: int = 4 + num_channels: int = 128 + num_chunks: int = 2 + activation: str = "GELU" + cpu_offload: bool = False + + @pytest.fixture def processor_init(): - num_layers = 4 - num_channels = 128 - num_chunks = 2 - activation = "GELU" - cpu_offload = False - return num_layers, num_channels, num_chunks, activation, cpu_offload + return ProcessorInit() @pytest.fixture() def base_processor(processor_init): - num_layers, num_channels, num_chunks, activation, cpu_offload = processor_init return BaseProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - activation=activation, - cpu_offload=cpu_offload, + num_layers=processor_init.num_layers, + num_channels=processor_init.num_channels, + num_chunks=processor_init.num_chunks, + activation=processor_init.activation, + cpu_offload=processor_init.cpu_offload, ) def test_base_processor_init(processor_init, base_processor): - num_layers, num_channels, num_chunks, *_ = processor_init + num_layers, num_channels, num_chunks = ( + processor_init.num_layers, + processor_init.num_channels, + processor_init.num_chunks, + ) assert isinstance(base_processor.num_chunks, int), "num_layers should be an integer" assert isinstance(base_processor.num_channels, int), "num_channels should be an integer" diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index 2505515..2fac3d7 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -5,6 +5,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+from dataclasses import dataclass + import pytest import torch from torch_geometric.data import HeteroData @@ -13,6 +15,21 @@ from anemoi.models.layers.processor import GNNProcessor +@dataclass +class GNNProcessorInit: + sub_graph: HeteroData + edge_attributes: list[str] + num_layers: int = 2 + num_channels: int = 128 + num_chunks: int = 2 + mlp_extra_layers: int = 0 + activation: str = "SiLU" + cpu_offload: bool = False + src_grid_size: int = 13 + dst_grid_size: int = 7 + trainable_size: int = 8 + + class TestGNNProcessor: """Test the GNNProcessor class.""" @@ -30,94 +47,37 @@ def fake_graph(self) -> tuple[HeteroData, int]: @pytest.fixture def graphconv_init(self, fake_graph: HeteroData): - num_layers = 2 - num_channels = 128 - num_chunks = 2 - mlp_extra_layers = 0 - activation = "SiLU" - cpu_offload = False - sub_graph = fake_graph[("nodes", "to", "nodes")] - edge_attributes = ["edge_attr1", "edge_attr2"] - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 8 - return ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, + return GNNProcessorInit( + sub_graph=fake_graph[("nodes", "to", "nodes")], edge_attributes=["edge_attr1", "edge_attr2"] ) @pytest.fixture def graphconv_processor(self, graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphconv_init return GNNProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - mlp_extra_layers=mlp_extra_layers, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - sub_graph_edge_attributes=edge_attributes, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, + num_layers=graphconv_init.num_layers, + num_channels=graphconv_init.num_channels, + num_chunks=graphconv_init.num_chunks, + mlp_extra_layers=graphconv_init.mlp_extra_layers, + activation=graphconv_init.activation, + cpu_offload=graphconv_init.cpu_offload, + sub_graph=graphconv_init.sub_graph, + sub_graph_edge_attributes=graphconv_init.edge_attributes, + src_grid_size=graphconv_init.src_grid_size, + dst_grid_size=graphconv_init.dst_grid_size, + trainable_size=graphconv_init.trainable_size, ) def test_graphconv_processor_init(self, graphconv_processor, graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphconv_init - assert graphconv_processor.num_chunks == num_chunks - assert graphconv_processor.num_channels == num_channels - assert graphconv_processor.chunk_size == num_layers // num_chunks + assert graphconv_processor.num_chunks == graphconv_init.num_chunks + assert graphconv_processor.num_channels == graphconv_init.num_channels + assert graphconv_processor.chunk_size == graphconv_init.num_layers // graphconv_init.num_chunks assert isinstance(graphconv_processor.trainable, TrainableTensor) def test_forward(self, graphconv_processor, graphconv_init): batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphconv_init + num_channels = graphconv_init.num_channels + trainable_size = graphconv_init.trainable_size 
+ x = torch.rand((self.NUM_EDGES, num_channels)) shard_shapes = [list(x.shape)] diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index dfba417..dd43403 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -5,6 +5,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from dataclasses import dataclass + import pytest import torch from torch_geometric.data import HeteroData @@ -13,6 +15,22 @@ from anemoi.models.layers.processor import GraphTransformerProcessor +@dataclass +class GraphTransformerProcessorConfig: + sub_graph: HeteroData + edge_attributes: list[str] + num_layers: int = 2 + num_channels: int = 128 + num_chunks: int = 2 + num_heads: int = 16 + mlp_hidden_ratio: int = 4 + activation: str = "GELU" + cpu_offload: bool = False + src_grid_size: int = 7 + dst_grid_size: int = 13 + trainable_size: int = 6 + + class TestGraphTransformerProcessor: """Test the GraphTransformerProcessor class.""" @@ -30,79 +48,31 @@ def fake_graph(self) -> tuple[HeteroData, int]: @pytest.fixture def graphtransformer_init(self, fake_graph: HeteroData): - num_layers = 2 - num_channels = 128 - num_chunks = 2 - num_heads = 16 - mlp_hidden_ratio = 4 - activation = "GELU" - cpu_offload = False - sub_graph = fake_graph[("nodes", "to", "nodes")] - edge_attributes = ["edge_attr1", "edge_attr2"] - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 6 - return ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, + return GraphTransformerProcessorConfig( + sub_graph=fake_graph[("nodes", "to", "nodes")], edge_attributes=["edge_attr1", "edge_attr2"] ) @pytest.fixture def graphtransformer_processor(self, graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphtransformer_init return GraphTransformerProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - sub_graph_edge_attributes=edge_attributes, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, + num_layers=graphtransformer_init.num_layers, + num_channels=graphtransformer_init.num_channels, + num_chunks=graphtransformer_init.num_chunks, + num_heads=graphtransformer_init.num_heads, + mlp_hidden_ratio=graphtransformer_init.mlp_hidden_ratio, + activation=graphtransformer_init.activation, + cpu_offload=graphtransformer_init.cpu_offload, + sub_graph=graphtransformer_init.sub_graph, + sub_graph_edge_attributes=graphtransformer_init.edge_attributes, + src_grid_size=graphtransformer_init.src_grid_size, + dst_grid_size=graphtransformer_init.dst_grid_size, + trainable_size=graphtransformer_init.trainable_size, ) def test_graphtransformer_processor_init(self, graphtransformer_processor, graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphtransformer_init + 
num_layers = graphtransformer_init.num_layers + num_channels = graphtransformer_init.num_channels + num_chunks = graphtransformer_init.num_chunks assert graphtransformer_processor.num_chunks == num_chunks assert graphtransformer_processor.num_channels == num_channels assert graphtransformer_processor.chunk_size == num_layers // num_chunks @@ -110,20 +80,8 @@ def test_graphtransformer_processor_init(self, graphtransformer_processor, graph def test_forward(self, graphtransformer_processor, graphtransformer_init): batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphtransformer_init + num_channels = graphtransformer_init.num_channels + trainable_size = graphtransformer_init.trainable_size x = torch.rand((self.NUM_EDGES, num_channels)) shard_shapes = [list(x.shape)] diff --git a/tests/layers/processor/test_transformer_processor.py b/tests/layers/processor/test_transformer_processor.py index d359c27..6d6dd03 100644 --- a/tests/layers/processor/test_transformer_processor.py +++ b/tests/layers/processor/test_transformer_processor.py @@ -5,69 +5,49 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from dataclasses import dataclass + import pytest import torch from anemoi.models.layers.processor import TransformerProcessor +@dataclass +class TransformerProcessorConfig: + num_layers: int = 2 + window_size: int = 10 + num_channels: int = 128 + num_chunks: int = 2 + activation: str = "GELU" + cpu_offload: bool = False + num_heads: int = 16 + mlp_hidden_ratio: int = 4 + + @pytest.fixture def transformer_processor_init(): - num_layers = 2 - window_size = 10 - num_channels = 128 - num_chunks = 2 - activation = "GELU" - cpu_offload = False - num_heads = 16 - mlp_hidden_ratio = 4 - return ( - num_layers, - window_size, - num_channels, - num_chunks, - activation, - cpu_offload, - num_heads, - mlp_hidden_ratio, - ) + return TransformerProcessorConfig() @pytest.fixture def transformer_processor(transformer_processor_init): - ( - num_layers, - window_size, - num_channels, - num_chunks, - activation, - cpu_offload, - num_heads, - mlp_hidden_ratio, - ) = transformer_processor_init return TransformerProcessor( - num_layers=num_layers, - window_size=window_size, - num_channels=num_channels, - num_chunks=num_chunks, - activation=activation, - cpu_offload=cpu_offload, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, + num_layers=transformer_processor_init.num_layers, + window_size=transformer_processor_init.window_size, + num_channels=transformer_processor_init.num_channels, + num_chunks=transformer_processor_init.num_chunks, + activation=transformer_processor_init.activation, + cpu_offload=transformer_processor_init.cpu_offload, + num_heads=transformer_processor_init.num_heads, + mlp_hidden_ratio=transformer_processor_init.mlp_hidden_ratio, ) def test_transformer_processor_init(transformer_processor, transformer_processor_init): - ( - num_layers, - _window_size, - num_channels, - num_chunks, - _activation, - _cpu_offload, - _num_heads, - _mlp_hidden_ratio, - ) = transformer_processor_init + num_layers = transformer_processor_init.num_layers + num_channels = transformer_processor_init.num_channels + num_chunks = transformer_processor_init.num_chunks assert isinstance(transformer_processor, TransformerProcessor) assert transformer_processor.num_chunks == 
num_chunks assert transformer_processor.num_channels == num_channels @@ -75,16 +55,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor def test_transformer_processor_forward(transformer_processor, transformer_processor_init): - ( - _num_layers, - _window_size, - num_channels, - _num_chunks, - _activation, - _cpu_offload, - _num_heads, - _mlp_hidden_ratio, - ) = transformer_processor_init + num_channels = transformer_processor_init.num_channels gridsize = 100 batch_size = 1 x = torch.rand(gridsize, num_channels) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index e47a0b9..9d0cb25 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -5,78 +5,105 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import pytest +from dataclasses import dataclass + +import hypothesis.strategies as st import torch +from hypothesis import given +from hypothesis import settings from anemoi.models.layers.mlp import MLP -@pytest.fixture -def batch_size(): - return 1 - - -@pytest.fixture -def nlatlon(): - return 1024 - - -@pytest.fixture -def num_features(): - return 64 - - -@pytest.fixture -def hdim(): - return 128 - - -@pytest.fixture -def num_out_feature(): - return 36 +@dataclass +class MLPConfig: + in_features: int = 48 + hidden_dim: int = 23 + out_features: int = 27 + n_extra_layers: int = 0 + activation: str = "SiLU" + final_activation: bool = False + layer_norm: bool = False + checkpoints: bool = False + + +from_config = dict( + init_config=st.builds( + MLPConfig, + in_features=st.integers(min_value=1, max_value=100), + hidden_dim=st.integers(min_value=1, max_value=100), + out_features=st.integers(min_value=1, max_value=100), + n_extra_layers=st.integers(min_value=0, max_value=10), + activation=st.sampled_from(("ReLU", "SiLU", "GELU")), + final_activation=st.booleans(), + layer_norm=st.booleans(), + checkpoints=st.booleans(), + ) +) + +run_model = dict( + batch_size=st.integers(min_value=1, max_value=2), + num_gridpoints=st.integers(min_value=1, max_value=512), +) class TestMLP: - def test_init(self, num_features, hdim, num_out_feature): - """Test MLP initialization.""" - mlp = MLP(num_features, hdim, num_out_feature, 0, "SiLU") - assert isinstance(mlp, MLP) - assert isinstance(mlp.model, torch.nn.Sequential) - assert len(mlp.model) == 6 - mlp = MLP(num_features, hdim, num_out_feature, 0, "ReLU", False, False, False) - assert len(mlp.model) == 5 + def create_model(self, init_config): + return MLP( + init_config.in_features, + init_config.hidden_dim, + init_config.out_features, + init_config.n_extra_layers, + init_config.activation, + init_config.final_activation, + init_config.layer_norm, + init_config.checkpoints, + ) + + @given(**from_config) + def test_init(self, init_config): + """Test MLP initialization.""" + mlp = self.create_model(init_config) - mlp = MLP(num_features, hdim, num_out_feature, 1, "SiLU", False, False, False) - assert len(mlp.model) == 7 + assert isinstance(mlp, MLP) + if isinstance(mlp.model, torch.nn.Sequential): + length = 3 + 2 * (init_config.n_extra_layers + 1) + init_config.layer_norm + init_config.final_activation + assert len(mlp.model) == length - def test_forwards(self, batch_size, nlatlon, num_features, hdim, num_out_feature): + @settings(deadline=None) + @given(**run_model, **from_config) + def test_forwards(self, batch_size, num_gridpoints, init_config): """Test MLP forward pass.""" - - mlp = MLP(num_features, hdim, num_out_feature, 
layer_norm=True) - x_in = torch.randn((batch_size, nlatlon, num_features), dtype=torch.float32, requires_grad=True) + mlp = self.create_model(init_config) + num_features = init_config.in_features + num_out_feature = init_config.out_features + x_in = torch.randn((batch_size, num_gridpoints, num_features), dtype=torch.float32, requires_grad=True) out = mlp(x_in) assert out.shape == ( batch_size, - nlatlon, + num_gridpoints, num_out_feature, ), "Output shape is not correct" - def test_backward(self, batch_size, nlatlon, num_features, hdim): + @given(**run_model, **from_config) + def test_backward(self, batch_size, num_gridpoints, init_config): """Test MLP backward pass.""" + mlp = self.create_model(init_config) + num_features = init_config.in_features + num_out_feature = init_config.out_features - x_in = torch.randn((batch_size, nlatlon, num_features), dtype=torch.float32, requires_grad=True) - mlp_1 = MLP(num_features, hdim, hdim, layer_norm=True) + x_in = torch.randn((batch_size, num_gridpoints, num_features), dtype=torch.float32, requires_grad=True) - y = mlp_1(x_in) - assert y.shape == (batch_size, nlatlon, hdim) + y = mlp(x_in) + assert y.shape == (batch_size, num_gridpoints, num_out_feature) loss = y.sum() print("running backward on the dummy loss ...") loss.backward() - for param in mlp_1.parameters(): + for param in mlp.parameters(): assert param.grad is not None, f"param.grad is None for {param}" assert ( param.grad.shape == param.shape
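

Note on the constructor changes above: the bare "*" added to each __init__ makes every parameter that follows it keyword-only, and the old fallback defaults (e.g. activation="GELU", hidden_dim=128) are removed so a missing config value fails loudly at construction time instead of silently training with a fallback. A minimal, self-contained sketch of that pattern; ExampleMapper is an illustrative stand-in, not a class from this diff:

class ExampleMapper:
    """Illustrative stand-in for the mappers above (not part of the diff)."""

    def __init__(
        self,
        *,  # everything after this star must be passed by keyword
        in_channels_src: int,
        in_channels_dst: int,
        hidden_dim: int,
        activation: str,
        cpu_offload: bool = False,
    ) -> None:
        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.cpu_offload = cpu_offload


mapper = ExampleMapper(
    in_channels_src=3,
    in_channels_dst=3,
    hidden_dim=128,
    activation="SiLU",
)
# ExampleMapper(3, 3, 128, "SiLU") now raises a TypeError: positional calls
# are rejected, which is why the super().__init__ calls in the diff are
# rewritten to pass every argument by name.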
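
The test refactor replaces tuple-returning fixtures (and their underscore-prefixed unpacking) with small dataclass configs whose fields are read by name. A minimal sketch of the pattern, assuming pytest; DummyConfig and config are illustrative names, not identifiers from the diff:

from dataclasses import dataclass

import pytest


@dataclass
class DummyConfig:
    """Illustrative config object, mirroring the *Config dataclasses above."""

    in_channels_src: int = 3
    hidden_dim: int = 128
    activation: str = "SiLU"


@pytest.fixture
def config() -> DummyConfig:
    # One named object instead of a long tuple: tests read config.hidden_dim
    # directly and no longer need placeholder variables to skip unused slots.
    return DummyConfig()


def test_defaults(config: DummyConfig) -> None:
    assert config.hidden_dim == 128
    assert config.activation == "SiLU"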
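
test_mlp.py additionally moves from fixed fixture values to Hypothesis property-based tests: st.builds draws whole config dataclasses and @given reruns each test across many sampled configurations. A rough sketch of the same idea, using torch.nn.Linear as a stand-in model rather than the anemoi MLP:

from dataclasses import dataclass

import hypothesis.strategies as st
import torch
from hypothesis import given
from hypothesis import settings


@dataclass
class LinearConfig:
    """Illustrative config; the real tests build an MLPConfig instead."""

    in_features: int = 8
    out_features: int = 4


configs = st.builds(
    LinearConfig,
    in_features=st.integers(min_value=1, max_value=64),
    out_features=st.integers(min_value=1, max_value=64),
)


@settings(deadline=None)  # constructing a fresh model per example can exceed the default per-example deadline
@given(init_config=configs, batch_size=st.integers(min_value=1, max_value=4))
def test_forward_shape(init_config: LinearConfig, batch_size: int) -> None:
    model = torch.nn.Linear(init_config.in_features, init_config.out_features)
    x = torch.randn(batch_size, init_config.in_features)
    assert model(x).shape == (batch_size, init_config.out_features)

The @settings(deadline=None) override mirrors the one applied to test_forwards in the diff and exists for the same reason: per-example model construction time is variable, so the timing check is disabled rather than tuned.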