diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index d7f5492..f508306 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -12,13 +12,16 @@
 from typing import Optional
 
 import einops
+import torch
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
+from torch_geometric.typing import PairTensor
 
 try:
     from flash_attn import flash_attn_func as attn_func
+    from flash_attn.layers.rotary import RotaryEmbedding
 except ImportError:
     from torch.nn.functional import scaled_dot_product_attention as attn_func
 
     _FLASH_ATTENTION_AVAILABLE = False
@@ -27,6 +30,7 @@
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
+from anemoi.models.layers.utils import AutocastLayerNorm
 
 LOGGER = logging.getLogger(__name__)
 
@@ -42,6 +46,8 @@ def __init__(
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        qk_norm: bool = False,
+        rotary_embeddings: bool = False,
     ):
         super().__init__()
@@ -55,8 +61,12 @@ def __init__(
         self.window_size = (window_size, window_size)  # flash attention
         self.dropout_p = dropout_p
         self.is_causal = is_causal
+        self.qk_norm = qk_norm
+        self.rotary_embeddings = rotary_embeddings
 
-        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
+        self.lin_q = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.lin_k = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.lin_v = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.attention = attn_func
 
         if not _FLASH_ATTENTION_AVAILABLE:
@@ -64,11 +74,23 @@
 
         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
 
-    def forward(
-        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
-    ) -> Tensor:
-        query, key, value = self.lin_qkv(x).chunk(3, -1)
+        if self.qk_norm:
+            self.q_norm = AutocastLayerNorm(self.head_dim, bias=False)
+            self.k_norm = AutocastLayerNorm(self.head_dim, bias=False)
 
+        if self.rotary_embeddings:  # find alternative implementation
+            assert _FLASH_ATTENTION_AVAILABLE, "Rotary embeddings require flash attention"
+            self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
+
+    def attention_computation(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        shapes: list,
+        batch_size: int,
+        model_comm_group: Optional[ProcessGroup] = None,
+    ) -> Tensor:
         if model_comm_group:
             assert (
                 model_comm_group.size() == 1 or batch_size == 1
@@ -83,16 +105,28 @@
             )
             for t in (query, key, value)
         )
-
         query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
         key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
         dropout_p = self.dropout_p if self.training else 0.0
+        if self.qk_norm:
+            query = self.q_norm(query)
+            key = self.k_norm(key)
+
         if _FLASH_ATTENTION_AVAILABLE:
             query, key, value = (
                 einops.rearrange(t, "batch heads grid vars -> batch grid heads vars")
                 for t in (query, key, value)
             )
+            if self.rotary_embeddings:  # can this be done in a better way?
+                key = key.unsqueeze(-3)
+                value = value.unsqueeze(-3)
+                keyvalue = torch.cat((key, value), dim=-3)
+                query, keyvalue = self.rotary_emb(
+                    query, keyvalue, max_seqlen=max(keyvalue.shape[1], query.shape[1])
+                )  # assumption seq const
+                key = keyvalue[:, :, 0, ...]
+                value = keyvalue[:, :, 1, ...]
             out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
             out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         else:
@@ -103,10 +137,29 @@ def forward(
                 is_causal=False,
                 dropout_p=dropout_p,
             )  # expects (batch heads grid variable) format
-
         out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
         out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+        return self.projection(out)
+
+    def forward(
+        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
+    ) -> Tensor:
+        query = self.lin_q(x)
+        key = self.lin_k(x)
+        value = self.lin_v(x)
+        return self.attention_computation(query, key, value, shapes, batch_size, model_comm_group)
-
-        out = self.projection(out)
-
-        return out
+
+
+class MultiHeadCrossAttention(MultiHeadSelfAttention):
+    """Multi Head Cross Attention Pytorch Layer."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self, x: PairTensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
+    ) -> Tensor:
+        query = self.lin_q(x[1])
+        key = self.lin_k(x[0])
+        value = self.lin_v(x[0])
+        return self.attention_computation(query, key, value, shapes, batch_size, model_comm_group)
diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index 60446d6..24ea489 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -28,6 +28,7 @@
 from anemoi.models.distributed.khop_edges import sort_edges_1hop_chunks
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
+from anemoi.models.layers.attention import MultiHeadCrossAttention
 from anemoi.models.layers.attention import MultiHeadSelfAttention
 from anemoi.models.layers.conv import GraphConv
 from anemoi.models.layers.conv import GraphTransformerConv
@@ -105,6 +106,53 @@ def forward(
         return x
 
 
+class TransformerMapperBlock(TransformerProcessorBlock):
+    """Transformer mapper block with MultiHeadCrossAttention and MLPs."""
+
+    def __init__(
+        self,
+        num_channels: int,
+        hidden_dim: int,
+        num_heads: int,
+        activation: str,
+        window_size: int,
+        dropout_p: float = 0.0,
+    ):
+        super().__init__(
+            num_channels=num_channels,
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            activation=activation,
+            window_size=window_size,
+            dropout_p=dropout_p,
+        )
+
+        self.attention = MultiHeadCrossAttention(
+            num_heads=num_heads,
+            embed_dim=num_channels,
+            window_size=window_size,
+            bias=False,
+            is_causal=False,
+            dropout_p=dropout_p,
+        )
+
+        self.layer_norm_src = nn.LayerNorm(num_channels)
+
+    def forward(
+        self,
+        x: OptPairTensor,
+        shapes: list,
+        batch_size: int,
+        model_comm_group: Optional[ProcessGroup] = None,
+    ) -> Tensor:
+        # Need to be out of place for gradient propagation
+        x_src = self.layer_norm_src(x[0])
+        x_dst = self.layer_norm1(x[1])
+        x_dst = x_dst + self.attention((x_src, x_dst), shapes, batch_size, model_comm_group=model_comm_group)
+        x_dst = x_dst + self.mlp(self.layer_norm2(x_dst))
+        return (x_src, x_dst), None  # logic expects return of edge_attr
+
+
 class GraphConvBaseBlock(BaseBlock):
     """Message passing block with MLPs for node embeddings."""
 
@@ -180,7 +228,7 @@ def __ini__(
         **kwargs,
     ):
         super().__init__(
-            self,
+            self,  # is this correct?
             in_channels=in_channels,
             out_channels=out_channels,
             mlp_extra_layers=mlp_extra_layers,
diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py
index 1ae4503..471901f 100644
--- a/src/anemoi/models/layers/mapper.py
+++ b/src/anemoi/models/layers/mapper.py
@@ -29,6 +29,7 @@
 from anemoi.models.distributed.shapes import get_shape_shards
 from anemoi.models.layers.block import GraphConvMapperBlock
 from anemoi.models.layers.block import GraphTransformerMapperBlock
+from anemoi.models.layers.block import TransformerMapperBlock
 from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.layers.mlp import MLP
@@ -703,3 +704,224 @@ def forward(
         _, x_dst = super().forward(x, batch_size, shard_shapes, model_comm_group)
 
         return x_dst
+
+
+class TransformerBaseMapper(BaseMapper):
+    """Transformer Base Mapper from hidden -> data or data -> hidden."""
+
+    def __init__(
+        self,
+        in_channels_src: int = 0,
+        in_channels_dst: int = 0,
+        hidden_dim: int = 128,
+        out_channels_dst: Optional[int] = None,
+        num_chunks: int = 1,
+        cpu_offload: bool = False,
+        activation: str = "GELU",
+        num_heads: int = 16,
+        mlp_hidden_ratio: int = 4,
+        window_size: Optional[int] = None,
+        dropout_p: float = 0.0,
+    ) -> None:
+        """Initialize TransformerBaseMapper.
+
+        Parameters
+        ----------
+        in_channels_src : int
+            Input channels of the source node
+        in_channels_dst : int
+            Input channels of the destination node
+        hidden_dim : int
+            Hidden dimension
+        num_heads: int
+            Number of heads to use, default 16
+        mlp_hidden_ratio: int
+            Ratio of mlp hidden dimension to embedding dimension, default 4
+        activation : str, optional
+            Activation function, by default "GELU"
+        cpu_offload : bool, optional
+            Whether to offload processing to CPU, by default False
+        out_channels_dst : Optional[int], optional
+            Output channels of the destination node, by default None
+        """
+        super().__init__(
+            in_channels_src,
+            in_channels_dst,
+            hidden_dim,
+            out_channels_dst=out_channels_dst,
+            num_chunks=num_chunks,
+            cpu_offload=cpu_offload,
+            activation=activation,
+        )
+
+        self.proc = TransformerMapperBlock(
+            num_channels=hidden_dim,
+            hidden_dim=mlp_hidden_ratio * hidden_dim,
+            num_heads=num_heads,
+            activation=activation,
+            window_size=window_size,
+            dropout_p=dropout_p,
+        )
+
+        self.offload_layers(cpu_offload)
+
+        self.emb_nodes_dst = nn.Linear(self.in_channels_dst, self.hidden_dim)
+
+    def forward(
+        self,
+        x: PairTensor,
+        batch_size: int,
+        shard_shapes: tuple[tuple[int], tuple[int]],
+        model_comm_group: Optional[ProcessGroup] = None,
+    ) -> PairTensor:
+
+        x_src, x_dst, shapes_src, shapes_dst = self.pre_process(x, shard_shapes, model_comm_group)
+
+        (x_src, x_dst), _ = self.proc(
+            (x_src, x_dst),
+            (shapes_src, shapes_dst),
+            batch_size,
+            model_comm_group,
+        )
+
+        x_dst = self.post_process(x_dst, shapes_dst, model_comm_group)
+
+        return x_dst
+
+
+class TransformerForwardMapper(ForwardMapperPreProcessMixin, TransformerBaseMapper):
+    """Transformer Mapper from data -> hidden."""
+
+    def __init__(
+        self,
+        in_channels_src: int = 0,
+        in_channels_dst: int = 0,
+        hidden_dim: int = 128,
+        out_channels_dst: Optional[int] = None,
+        num_chunks: int = 1,
+        cpu_offload: bool = False,
+        activation: str = "GELU",
+        num_heads: int = 16,
+        mlp_hidden_ratio: int = 4,
+        window_size: Optional[int] = None,
+        dropout_p: float = 0.0,
+        **kwargs,  # accept unused extra arguments, e.g. subgraph
+    ) -> None:
+        """Initialize TransformerForwardMapper.
+
+        Parameters
+        ----------
+        in_channels_src : int
+            Input channels of the source node
+        in_channels_dst : int
+            Input channels of the destination node
+        hidden_dim : int
+            Hidden dimension
+        num_heads: int
+            Number of heads to use, default 16
+        mlp_hidden_ratio: int
+            Ratio of mlp hidden dimension to embedding dimension, default 4
+        activation : str, optional
+            Activation function, by default "GELU"
+        cpu_offload : bool, optional
+            Whether to offload processing to CPU, by default False
+        out_channels_dst : Optional[int], optional
+            Output channels of the destination node, by default None
+        """
+        super().__init__(
+            in_channels_src,
+            in_channels_dst,
+            hidden_dim,
+            out_channels_dst=out_channels_dst,
+            num_chunks=num_chunks,
+            cpu_offload=cpu_offload,
+            activation=activation,
+            num_heads=num_heads,
+            mlp_hidden_ratio=mlp_hidden_ratio,
+            window_size=window_size,
+            dropout_p=dropout_p,
+        )
+
+        self.emb_nodes_src = nn.Linear(self.in_channels_src, self.hidden_dim)
+
+    def forward(
+        self,
+        x: PairTensor,
+        batch_size: int,
+        shard_shapes: tuple[tuple[int], tuple[int], tuple[int], tuple[int]],
+        model_comm_group: Optional[ProcessGroup] = None,
+    ) -> PairTensor:
+        x_dst = super().forward(x, batch_size, shard_shapes, model_comm_group)
+        return x[0], x_dst
+
+
+class TransformerBackwardMapper(BackwardMapperPostProcessMixin, TransformerBaseMapper):
+    """Transformer Mapper from hidden -> data."""
+
+    def __init__(
+        self,
+        in_channels_src: int = 0,
+        in_channels_dst: int = 0,
+        hidden_dim: int = 128,
+        out_channels_dst: Optional[int] = None,
+        num_chunks: int = 1,
+        cpu_offload: bool = False,
+        activation: str = "GELU",
+        num_heads: int = 16,
+        mlp_hidden_ratio: int = 4,
+        window_size: Optional[int] = None,
+        dropout_p: float = 0.0,
+        **kwargs,  # accept unused extra arguments, e.g. subgraph
+    ) -> None:
+        """Initialize TransformerBackwardMapper.
+
+        Parameters
+        ----------
+        in_channels_src : int
+            Input channels of the source node
+        in_channels_dst : int
+            Input channels of the destination node
+        hidden_dim : int
+            Hidden dimension
+        num_heads: int
+            Number of heads to use, default 16
+        mlp_hidden_ratio: int
+            Ratio of mlp hidden dimension to embedding dimension, default 4
+        activation : str, optional
+            Activation function, by default "GELU"
+        cpu_offload : bool, optional
+            Whether to offload processing to CPU, by default False
+        out_channels_dst : Optional[int], optional
+            Output channels of the destination node, by default None
+        """
+        super().__init__(
+            in_channels_src,
+            in_channels_dst,
+            hidden_dim,
+            out_channels_dst=out_channels_dst,
+            num_chunks=num_chunks,
+            cpu_offload=cpu_offload,
+            activation=activation,
+            num_heads=num_heads,
+            mlp_hidden_ratio=mlp_hidden_ratio,
+            window_size=window_size,
+            dropout_p=dropout_p,
+        )
+
+        self.node_data_extractor = nn.Sequential(
+            nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
+        )
+
+    def pre_process(self, x, shard_shapes, model_comm_group=None):
+        x_src, x_dst, shapes_src, shapes_dst = super().pre_process(x, shard_shapes, model_comm_group)
+        shapes_src = change_channels_in_shape(shapes_src, self.hidden_dim)
+        x_dst = shard_tensor(x_dst, 0, shapes_dst, model_comm_group)
+        x_dst = self.emb_nodes_dst(x_dst)
+        shapes_dst = change_channels_in_shape(shapes_dst, self.hidden_dim)
+        return x_src, x_dst, shapes_src, shapes_dst
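
Usage sketch (not part of the patch above): the snippet below reduces the cross-attention wiring that MultiHeadCrossAttention and TransformerMapperBlock implement, with queries taken from the destination nodes and keys/values from the source nodes, to plain PyTorch. Sequence sharding, flash attention, window_size and rotary embeddings are left out, and the class name, dimensions, and tensor shapes here are illustrative assumptions rather than anemoi-models API.

# Standalone illustration only; not anemoi-models code.
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention


class TinyCrossAttention(nn.Module):
    """Cross-attention sketch: query from destination nodes, key/value from source nodes."""

    def __init__(self, embed_dim: int, num_heads: int, qk_norm: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # separate projections, mirroring lin_q / lin_k / lin_v in the patch
        self.lin_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.lin_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.lin_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
        self.qk_norm = qk_norm
        if qk_norm:
            # per-head LayerNorm on queries and keys, analogous to q_norm / k_norm
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(self, x_src: torch.Tensor, x_dst: torch.Tensor) -> torch.Tensor:
        # x_src: (batch, n_src, embed_dim), x_dst: (batch, n_dst, embed_dim)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            b, n, _ = t.shape
            return t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.lin_q(x_dst))   # destination nodes ask the questions
        key = split_heads(self.lin_k(x_src))     # source nodes provide keys
        value = split_heads(self.lin_v(x_src))   # ... and values
        if self.qk_norm:
            query, key = self.q_norm(query), self.k_norm(key)
        out = scaled_dot_product_attention(query, key, value)  # (batch, heads, n_dst, head_dim)
        out = out.transpose(1, 2).reshape(x_dst.shape[0], x_dst.shape[1], -1)
        return self.projection(out)


# e.g. 8 source ("data") nodes attended to by 4 destination ("hidden") nodes
attn = TinyCrossAttention(embed_dim=64, num_heads=4, qk_norm=True)
out = attn(torch.randn(1, 8, 64), torch.randn(1, 4, 64))
print(out.shape)  # torch.Size([1, 4, 64])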