Commit

Implementation for chunking and some optimization regarding tensor operations and freeing up memory.

einrone committed Nov 19, 2024
1 parent e566510 commit 7eca6c8
Showing 4 changed files with 99 additions and 37 deletions.
55 changes: 34 additions & 21 deletions src/anemoi/models/distributed/khop_edges.py
@@ -46,7 +46,7 @@ def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops:
return edge_attr[mask_to_index(edge_mask_k)], edge_index_k


def sort_edges_1hop(
def sort_edges_1hop_sharding(
num_nodes: Union[int, tuple[int, int]],
edge_attr: Tensor,
edge_index: Adj,
@@ -74,29 +74,42 @@ def sort_edges_1hop(
if mgroup:
num_chunks = dist.get_world_size(group=mgroup)

if isinstance(num_nodes, int):
node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
else:
nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)

edge_index_list = []
edge_attr_list = []
for node_chunk in node_chunks:
if isinstance(num_nodes, int):
edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
else:
edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
(nodes_src, node_chunk),
edge_index,
edge_attr,
size=(num_nodes[0], num_nodes[1]),
)
edge_index_list.append(edge_index_chunk)
edge_attr_list.append(edge_attr_chunk)
edge_attr_list, edge_index_list = sorted_1hop_chunks(
num_nodes=num_nodes,
edge_attr=edge_attr,
edge_index=edge_index,
num_chunks=num_chunks
)

edge_index_shapes = [x.shape for x in edge_index_list]
edge_attr_shapes = [x.shape for x in edge_attr_list]

return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes

return edge_attr, edge_index, [], []

def sorted_1hop_chunks(
num_nodes: Union[int, tuple[int, int]], edge_attr: Tensor, edge_index: Adj, num_chunks: int
) -> tuple[list[Tensor], list[Adj]]:
if isinstance(num_nodes, int):
node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
else:
nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)

edge_index_list = []
edge_attr_list = []
for node_chunk in node_chunks:
if isinstance(num_nodes, int):
edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
else:
edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
(nodes_src, node_chunk),
edge_index,
edge_attr,
size=(num_nodes[0], num_nodes[1]),
)
edge_index_list.append(edge_index_chunk)
edge_attr_list.append(edge_attr_chunk)

return edge_attr_list, edge_index_list
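
For context, a minimal usage sketch of the new sorted_1hop_chunks helper on a toy bipartite graph; the tensors, sizes, and chunk count below are illustrative and not part of this commit, and it assumes anemoi-models and PyTorch are installed:

import torch
from anemoi.models.distributed.khop_edges import sorted_1hop_chunks

# 4 edges from 4 source nodes to 2 destination nodes
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_attr = torch.randn(4, 8)  # one feature row per edge

edge_attr_chunks, edge_index_chunks = sorted_1hop_chunks(
    num_nodes=(4, 2),  # (num_source_nodes, num_destination_nodes)
    edge_attr=edge_attr,
    edge_index=edge_index,
    num_chunks=2,
)
# Each chunk holds only the edges whose destination node falls into that node
# chunk, so downstream message passing can process one chunk at a time.
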
73 changes: 61 additions & 12 deletions src/anemoi/models/layers/block.py
@@ -7,6 +7,7 @@
# nor does it submit to any jurisdiction.
#

import os
import logging
from abc import ABC
from abc import abstractmethod
@@ -23,6 +24,7 @@

from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.graph import sync_tensor
from anemoi.models.distributed.khop_edges import sorted_1hop_chunks
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.models.layers.attention import MultiHeadSelfAttention
@@ -32,6 +34,7 @@

LOGGER = logging.getLogger(__name__)

NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))

class BaseBlock(nn.Module, ABC):
"""Base class for network blocks."""
@@ -492,11 +495,17 @@ def forward(
query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

# TODO: remove magic number
num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference
num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE # reduce memory for inference

if num_chunks > 1:
edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
#edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
#edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
edge_attr_list, edge_index_list = sorted_1hop_chunks(
num_nodes=size, edge_attr=edges, edge_index=edge_index,
num_chunks=num_chunks
)

out = torch.zeros_like(query, dtype = query.dtype)
for i in range(num_chunks):
out1 = self.conv(
query=query,
@@ -506,20 +515,44 @@ def forward(
edge_index=edge_index_list[i],
size=size,
)
if i == 0:
"""if i == 0:
out = torch.zeros_like(out1)
out = out + out1
out = out + out1"""
out.add_(out1)

else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
out = self.projection(out + x_r)
#out = self.projection(out + x_r)


projected = []
for chunk in torch.tensor_split(out + x_r, num_chunks, dim = 0):
projected.append(self.projection(chunk))
del chunk

out = torch.cat(projected, dim = 0)

out = out + x_skip[1]
nodes_new_dst = self.node_dst_mlp(out) + out

nodes_new_src = self.node_src_mlp(x_skip[0]) + x_skip[0] if self.update_src_nodes else x_skip[0]
nodes_new_dst = []
for chunk in out.tensor_split(num_chunks, dim = 0 ):
nodes_new_dst.append(self.node_dst_mlp(chunk) + chunk)
del chunk

nodes_new_dst = torch.cat(nodes_new_dst, dim = 0)

if self.update_src_nodes:
nodes_new_src = []
for chunk in x_skip[0].tensor_split(num_chunks, dim=0):
nodes_new_src.append(self.node_src_mlp(chunk) + chunk)
del chunk

nodes_new_src = torch.cat(nodes_new_src, dim = 0)

else:
nodes_new_src = x_skip[0]
nodes_new = (nodes_new_src, nodes_new_dst)

return nodes_new, edge_attr
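
A note on the accumulation change above: each chunk's convolution output is now added in place into a preallocated buffer rather than rebuilt via out = out + out1. A rough standalone sketch of that pattern, where the per-chunk computation is a stand-in for self.conv and only ANEMOI_INFERENCE_NUM_CHUNKS comes from this commit:

import os
import torch

num_chunks = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))

query = torch.randn(256, 32)   # stand-in for the sharded query tensor
edges = torch.randn(1024, 32)  # stand-in for edge features

out = torch.zeros_like(query)  # preallocated accumulator
for edge_chunk in torch.tensor_split(edges, num_chunks, dim=0):
    partial = query * edge_chunk.mean()  # stand-in for one self.conv(...) call
    out.add_(partial)  # in-place add avoids an extra full-size copy per chunk
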
@@ -602,11 +635,13 @@ def forward(
query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

# TODO: Is this alright?
num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference
num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE # reduce memory for inference

if num_chunks > 1:
edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)

out = torch.zeros_like(query, dtype = query.dtype)
for i in range(num_chunks):
out1 = self.conv(
query=query,
@@ -616,16 +651,30 @@ def forward(
edge_index=edge_index_list[i],
size=size,
)
if i == 0:
out = torch.zeros_like(out1)
out = out + out1
#if i == 0:
# out = torch.zeros_like(out1)
#out = out + out1
out.add_(out1)
else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)

"""projected_chunks = []
for chunk in torch.tensor_split(out + x_r, num_chunks ,dim = 0):
projected_chunks.append(self.projection(chunk))
del chunk
out = torch.cat(projected_chunks, dim = 0)"""
out = self.projection(out + x_r)

out = out + x_skip
nodes_new = self.node_dst_mlp(out) + out
"""nodes_new_chunks = []
for chunk in torch.tensor_split(out, num_chunks, dim = 0):
nodes_new_chunks.append(self.node_dst_mlp(chunk) + chunk)
del chunk
nodes_new = torch.cat(nodes_new_chunks, dim = 0)"""

return nodes_new, edge_attr
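
The projection and node-MLP changes in this file follow the same memory-saving idea: apply a residual module chunk by chunk along the node dimension and concatenate the results. A self-contained sketch of that pattern, with illustrative sizes and module names:

import torch

num_chunks = 4
node_mlp = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

out = torch.randn(10_000, 64)  # stand-in for destination-node features
chunks = []
for chunk in out.tensor_split(num_chunks, dim=0):
    chunks.append(node_mlp(chunk) + chunk)  # residual MLP applied per chunk
out = torch.cat(chunks, dim=0)  # peak activation memory scales with one chunk
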
4 changes: 2 additions & 2 deletions src/anemoi/models/layers/mapper.py
@@ -21,7 +21,7 @@

from anemoi.models.distributed.graph import gather_tensor
from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.khop_edges import sort_edges_1hop
from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding
from anemoi.models.distributed.shapes import change_channels_in_shape
from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.block import GraphConvMapperBlock
@@ -465,7 +465,7 @@ def __init__(
def prepare_edges(self, size, batch_size, model_comm_group=None):
edge_attr = self.trainable(self.edge_attr, batch_size)
edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop(
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding(
size, edge_attr, edge_index, model_comm_group
)

4 changes: 2 additions & 2 deletions src/anemoi/models/layers/processor.py
@@ -17,7 +17,7 @@
from torch.utils.checkpoint import checkpoint

from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.khop_edges import sort_edges_1hop
from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding
from anemoi.models.distributed.shapes import change_channels_in_shape
from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.chunk import GNNProcessorChunk
@@ -229,7 +229,7 @@ def forward(
edge_attr = self.trainable(self.edge_attr, batch_size)
edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)
target_nodes = sum(x[0] for x in shape_nodes)
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop(
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding(
target_nodes,
edge_attr,
edge_index,
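
Callers such as the mapper and processor only pick up the new name. A minimal sketch of calling the renamed helper outside a model-parallel setup, with toy tensors and the comm-group argument passed positionally as None (assuming anemoi-models is installed):

import torch
from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = torch.randn(3, 4)

edge_attr, edge_index, shapes_attr, shapes_idx = sort_edges_1hop_sharding(
    10, edge_attr, edge_index, None
)
# Without a communication group the edges pass through unchanged and the
# shape lists are empty.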
