diff --git a/src/anemoi/models/distributed/khop_edges.py b/src/anemoi/models/distributed/khop_edges.py
index 5b4bd81..7ce0d3a 100644
--- a/src/anemoi/models/distributed/khop_edges.py
+++ b/src/anemoi/models/distributed/khop_edges.py
@@ -46,7 +46,7 @@ def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops:
     return edge_attr[mask_to_index(edge_mask_k)], edge_index_k
 
 
-def sort_edges_1hop(
+def sort_edges_1hop_sharding(
     num_nodes: Union[int, tuple[int, int]],
     edge_attr: Tensor,
     edge_index: Adj,
@@ -74,29 +74,42 @@ def sort_edges_1hop(
     if mgroup:
         num_chunks = dist.get_world_size(group=mgroup)
 
-        if isinstance(num_nodes, int):
-            node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
-        else:
-            nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
-            node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)
-
-        edge_index_list = []
-        edge_attr_list = []
-        for node_chunk in node_chunks:
-            if isinstance(num_nodes, int):
-                edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
-            else:
-                edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
-                    (nodes_src, node_chunk),
-                    edge_index,
-                    edge_attr,
-                    size=(num_nodes[0], num_nodes[1]),
-                )
-            edge_index_list.append(edge_index_chunk)
-            edge_attr_list.append(edge_attr_chunk)
+        edge_attr_list, edge_index_list = sorted_1hop_chunks(
+            num_nodes=num_nodes,
+            edge_attr=edge_attr,
+            edge_index=edge_index,
+            num_chunks=num_chunks
+        )
+
         edge_index_shapes = [x.shape for x in edge_index_list]
         edge_attr_shapes = [x.shape for x in edge_attr_list]
 
         return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes
 
     return edge_attr, edge_index, [], []
+
+
+def sorted_1hop_chunks(
+    num_nodes: Union[int, tuple[int, int]], edge_attr: Tensor, edge_index: Adj, num_chunks: int
+) -> tuple[list[Tensor], list[Adj]]:
+    if isinstance(num_nodes, int):
+        node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
+    else:
+        nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
+        node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)
+
+    edge_index_list = []
+    edge_attr_list = []
+    for node_chunk in node_chunks:
+        if isinstance(num_nodes, int):
+            edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
+        else:
+            edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
+                (nodes_src, node_chunk),
+                edge_index,
+                edge_attr,
+                size=(num_nodes[0], num_nodes[1]),
+            )
+        edge_index_list.append(edge_index_chunk)
+        edge_attr_list.append(edge_attr_chunk)
+
+    return edge_attr_list, edge_index_list
\ No newline at end of file
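
For reference, a minimal, dependency-free sketch of the chunking idea behind `sorted_1hop_chunks`: destination nodes are split into contiguous ranges with `tensor_split`, and each chunk keeps only the edges whose destination falls in that range. The helper name `chunk_edges_by_dst` and the boolean-mask selection are illustrative stand-ins for the `bipartite_subgraph` / `get_k_hop_edges` calls used above, not the anemoi API.

```python
import torch


def chunk_edges_by_dst(edge_index: torch.Tensor, edge_attr: torch.Tensor, num_dst: int, num_chunks: int):
    """Split destination nodes into contiguous chunks and keep, per chunk, the
    edges (and their attributes) whose destination lies in that chunk."""
    dst_chunks = torch.arange(num_dst).tensor_split(num_chunks)
    edge_attr_list, edge_index_list = [], []
    for dst_chunk in dst_chunks:
        # chunks from arange are contiguous ranges, so an interval test selects the chunk's edges
        mask = (edge_index[1] >= dst_chunk[0]) & (edge_index[1] <= dst_chunk[-1])
        edge_index_list.append(edge_index[:, mask])
        edge_attr_list.append(edge_attr[mask])
    return edge_attr_list, edge_index_list


# toy bipartite graph: 4 source nodes, 6 destination nodes, 8 edges
edge_index = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3],
                           [0, 1, 2, 3, 4, 5, 0, 5]])
edge_attr = torch.randn(8, 16)
attrs, idxs = chunk_edges_by_dst(edge_index, edge_attr, num_dst=6, num_chunks=3)
assert sum(a.shape[0] for a in attrs) == edge_attr.shape[0]  # every edge lands in exactly one chunk
```

Because every edge lands in exactly one chunk, concatenating the per-chunk results (as `sort_edges_1hop_sharding` does with `torch.cat`) reproduces the full edge set, just reordered by destination chunk.
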
diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index c43d059..564a9bc 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 #
 
+import os
 import logging
 from abc import ABC
 from abc import abstractmethod
@@ -23,6 +24,7 @@
 from anemoi.models.distributed.graph import shard_tensor
 from anemoi.models.distributed.graph import sync_tensor
+from anemoi.models.distributed.khop_edges import sorted_1hop_chunks
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 from anemoi.models.layers.attention import MultiHeadSelfAttention
@@ -32,6 +34,7 @@
 
 LOGGER = logging.getLogger(__name__)
 
+NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))
 
 class BaseBlock(nn.Module, ABC):
     """Base class for network blocks."""
@@ -492,11 +495,17 @@ def forward(
         query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
 
         # TODO: remove magic number
-        num_chunks = self.num_chunks if self.training else 4  # reduce memory for inference
+        num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE  # reduce memory for inference
 
         if num_chunks > 1:
-            edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
-            edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
+            # edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
+            # edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
+            edge_attr_list, edge_index_list = sorted_1hop_chunks(
+                num_nodes=size, edge_attr=edges, edge_index=edge_index,
+                num_chunks=num_chunks
+            )
+
+            out = torch.zeros_like(query, dtype=query.dtype)
             for i in range(num_chunks):
                 out1 = self.conv(
                     query=query,
@@ -506,20 +515,44 @@ def forward(
                     edge_index=edge_index_list[i],
                     size=size,
                 )
-                if i == 0:
+                """if i == 0:
                     out = torch.zeros_like(out1)
-                out = out + out1
+                out = out + out1"""
+                out.add_(out1)
+
         else:
             out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
 
         out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
-        out = self.projection(out + x_r)
+        # out = self.projection(out + x_r)
+
+        projected = []
+        for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0):
+            projected.append(self.projection(chunk))
+            del chunk
+
+        out = torch.cat(projected, dim=0)
+
         out = out + x_skip[1]
-        nodes_new_dst = self.node_dst_mlp(out) + out
-        nodes_new_src = self.node_src_mlp(x_skip[0]) + x_skip[0] if self.update_src_nodes else x_skip[0]
+        nodes_new_dst = []
+        for chunk in out.tensor_split(num_chunks, dim=0):
+            nodes_new_dst.append(self.node_dst_mlp(chunk) + chunk)
+            del chunk
+        nodes_new_dst = torch.cat(nodes_new_dst, dim=0)
+
+        if self.update_src_nodes:
+            nodes_new_src = []
+            for chunk in x_skip[0].tensor_split(num_chunks, dim=0):
+                nodes_new_src.append(self.node_src_mlp(chunk) + chunk)
+                del chunk
+
+            nodes_new_src = torch.cat(nodes_new_src, dim=0)
+
+        else:
+            nodes_new_src = x_skip[0]
 
         nodes_new = (nodes_new_src, nodes_new_dst)
 
         return nodes_new, edge_attr
@@ -602,11 +635,13 @@ def forward(
         query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
 
         # TODO: Is this alright?
-        num_chunks = self.num_chunks if self.training else 4  # reduce memory for inference
+        num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE  # reduce memory for inference
 
         if num_chunks > 1:
             edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
             edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
+
+            out = torch.zeros_like(query, dtype=query.dtype)
             for i in range(num_chunks):
                 out1 = self.conv(
                     query=query,
@@ -616,16 +651,30 @@ def forward(
                     edge_index=edge_index_list[i],
                     size=size,
                 )
-                if i == 0:
-                    out = torch.zeros_like(out1)
-                out = out + out1
+                # if i == 0:
+                #     out = torch.zeros_like(out1)
+                # out = out + out1
+                out.add_(out1)
         else:
             out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
 
         out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
+
+        """projected_chunks = []
+        for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0):
+            projected_chunks.append(self.projection(chunk))
+            del chunk
+
+        out = torch.cat(projected_chunks, dim=0)"""
         out = self.projection(out + x_r)
         out = out + x_skip
 
         nodes_new = self.node_dst_mlp(out) + out
+        """nodes_new_chunks = []
+        for chunk in torch.tensor_split(out, num_chunks, dim=0):
+            nodes_new_chunks.append(self.node_dst_mlp(chunk) + chunk)
+            del chunk
+
+        nodes_new = torch.cat(nodes_new_chunks, dim=0)"""
 
         return nodes_new, edge_attr
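
In the mapper block above, the projection and node MLPs are now applied chunk-wise along the node dimension so only one chunk's activations are materialised at a time. A self-contained sketch of that pattern; the name `chunked_residual_mlp` is illustrative and not part of anemoi-models.

```python
import torch
import torch.nn as nn


def chunked_residual_mlp(mlp: nn.Module, x: torch.Tensor, num_chunks: int) -> torch.Tensor:
    """Apply mlp(chunk) + chunk per chunk and re-assemble; row-wise layers make this
    equal to mlp(x) + x, but peak memory scales with the chunk size, not the full
    node dimension."""
    outputs = []
    for chunk in x.tensor_split(num_chunks, dim=0):
        outputs.append(mlp(chunk) + chunk)
    return torch.cat(outputs, dim=0)


mlp = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
x = torch.randn(1000, 32)
with torch.no_grad():
    reference = mlp(x) + x                              # single full-tensor pass
    chunked = chunked_residual_mlp(mlp, x, num_chunks=4)
assert torch.allclose(reference, chunked, atol=1e-5)
```

The same reasoning applies to the conv loop: `out` is pre-allocated once with `torch.zeros_like(query)` and each chunk's contribution is accumulated in place with `add_()`, instead of allocating a fresh `zeros_like(out1)` inside the first iteration.
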
diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py
index 4aaeb10..0adcb11 100644
--- a/src/anemoi/models/layers/mapper.py
+++ b/src/anemoi/models/layers/mapper.py
@@ -21,7 +21,7 @@
 from anemoi.models.distributed.graph import gather_tensor
 from anemoi.models.distributed.graph import shard_tensor
-from anemoi.models.distributed.khop_edges import sort_edges_1hop
+from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding
 from anemoi.models.distributed.shapes import change_channels_in_shape
 from anemoi.models.distributed.shapes import get_shape_shards
 from anemoi.models.layers.block import GraphConvMapperBlock
@@ -465,7 +465,7 @@ def __init__(
     def prepare_edges(self, size, batch_size, model_comm_group=None):
         edge_attr = self.trainable(self.edge_attr, batch_size)
         edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)
-        edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop(
+        edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding(
             size, edge_attr, edge_index, model_comm_group
         )
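
`prepare_edges` only changes the name it imports; the calling convention of `sort_edges_1hop_sharding` is unchanged. Below is a toy stand-in (not the anemoi implementation) showing the shape of that contract, in particular the pass-through behaviour when no `model_comm_group` is given, which is visible in the khop_edges.py diff above.

```python
import torch
from torch import Tensor


def sort_edges_1hop_sharding_stub(num_nodes, edge_attr: Tensor, edge_index: Tensor, mgroup=None):
    """Mimic the interface seen in the diff: with no process group the edges are
    returned untouched together with empty shape lists; with a group the real
    function returns the edges regrouped per rank plus their per-chunk shapes."""
    if mgroup is None:
        return edge_attr, edge_index, [], []
    raise NotImplementedError("sharded path needs torch.distributed and torch_geometric")


edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = torch.randn(3, 8)
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding_stub(3, edge_attr, edge_index)
assert shapes_edge_attr == [] and shapes_edge_idx == []
```
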
diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py
index f1e91f6..740bd53 100644
--- a/src/anemoi/models/layers/processor.py
+++ b/src/anemoi/models/layers/processor.py
@@ -17,7 +17,7 @@
 from torch.utils.checkpoint import checkpoint
 
 from anemoi.models.distributed.graph import shard_tensor
-from anemoi.models.distributed.khop_edges import sort_edges_1hop
+from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding
 from anemoi.models.distributed.shapes import change_channels_in_shape
 from anemoi.models.distributed.shapes import get_shape_shards
 from anemoi.models.layers.chunk import GNNProcessorChunk
@@ -229,7 +229,7 @@ def forward(
         edge_attr = self.trainable(self.edge_attr, batch_size)
         edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)
         target_nodes = sum(x[0] for x in shape_nodes)
-        edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop(
+        edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding(
             target_nodes,
             edge_attr,
             edge_index,
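
With these changes the inference-time chunk count is no longer the hard-coded 4 but is read once at import time from `ANEMOI_INFERENCE_NUM_CHUNKS` (default `1`, i.e. no chunking); training still uses `self.num_chunks`. A minimal illustration, assuming the variable is exported before `anemoi.models.layers.block` is imported:

```python
import os

# e.g. reproduce the previous behaviour of splitting work into 4 chunks at inference
os.environ["ANEMOI_INFERENCE_NUM_CHUNKS"] = "4"

# mirrors the module-level constant added in block.py
NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))
assert NUM_CHUNKS_INFERENCE == 4
```
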