11 add configurability to dropout in multiheadselfattention module #12 (Merged)

CHANGELOG.md (3 additions, 0 deletions)

@@ -14,9 +14,12 @@ Keep it human-readable, your future self will thank you!

### Added

- CI workflow to update the changelog on release
- configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.


### Changed

- Update CI to inherit from common infrastructure reusable workflows

src/anemoi/models/layers/attention.py (11 additions, 4 deletions)

@@ -40,19 +40,19 @@ def __init__(
bias: bool = False,
is_causal: bool = False,
window_size: Optional[int] = None,
dropout: float = 0.0,
dropout_p: float = 0.0,
):
super().__init__()

assert (
embed_dim % num_heads == 0
), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

self.dropout = dropout
self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = embed_dim // num_heads # q k v
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
@@ -86,15 +86,22 @@ def forward(
query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
query, key, value = (
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
out = self.attention(query, key, value, causal=False, window_size=self.window_size)
out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
else:
out = self.attention(query, key, value, is_causal=False) # expects (batch heads grid variable) format
out = self.attention(
query,
key,
value,
is_causal=False,
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
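The core of the change is visible above: MultiHeadSelfAttention now stores `dropout_p` and passes it to the attention kernel on both code paths, but only while the module is in training mode (`dropout_p = self.dropout_p if self.training else 0.0`). A minimal, self-contained sketch of that pattern follows; it assumes PyTorch >= 2.0 for `torch.nn.functional.scaled_dot_product_attention`, and the class name `TinySelfAttention` is purely illustrative, not part of anemoi-models.

```python
# Minimal sketch of the training-only attention-dropout pattern used in this PR.
# Illustrative only; not the full anemoi-models MultiHeadSelfAttention.
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention


class TinySelfAttention(nn.Module):
    def __init__(self, num_heads: int, embed_dim: int, dropout_p: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout_p
        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.projection = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, embed_dim)
        q, k, v = self.lin_qkv(x).chunk(3, dim=-1)
        q, k, v = (
            t.view(*x.shape[:2], self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )  # -> (batch, heads, seq, head_dim)

        # Dropout is applied only while training, mirroring the diff above.
        dropout_p = self.dropout_p if self.training else 0.0
        out = scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=False)

        out = out.transpose(1, 2).reshape(*x.shape)
        return self.projection(out)
```

Because the probability is zeroed outside training mode, inference remains deterministic regardless of the configured `dropout_p`.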

src/anemoi/models/layers/block.py (10 additions, 2 deletions)

@@ -55,7 +55,15 @@ def forward(
class TransformerProcessorBlock(BaseBlock):
"""Transformer block with MultiHeadSelfAttention and MLPs."""

def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size: int):
def __init__(
self,
num_channels: int,
hidden_dim: int,
num_heads: int,
activation: str,
window_size: int,
dropout_p: float = 0.0,
):
super().__init__()

try:
@@ -72,7 +80,7 @@ def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size:
window_size=window_size,
bias=False,
is_causal=False,
dropout=0.0,
dropout_p=dropout_p,
)

self.mlp = nn.Sequential(
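TransformerProcessorBlock gains the same keyword and simply forwards it to its MultiHeadSelfAttention. A hedged usage sketch, with the positional argument order taken from the tests further down in this diff and purely illustrative values:

```python
# Illustrative only: constructing the block with attention dropout enabled.
from anemoi.models.layers.block import TransformerProcessorBlock

block = TransformerProcessorBlock(
    num_channels=64,   # must be divisible by num_heads
    hidden_dim=256,
    num_heads=8,
    activation="GELU",
    window_size=512,
    dropout_p=0.1,     # new keyword; 0.0 keeps the previous (no-dropout) behaviour
)
```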

src/anemoi/models/layers/chunk.py (4 additions, 0 deletions)

@@ -73,6 +73,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
) -> None:
"""Initialize TransformerProcessor.

@@ -88,6 +89,8 @@ def __init__(
ratio of mlp hidden dimension to embedding dimension, default 4
activation : str, optional
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -98,6 +101,7 @@
num_heads=num_heads,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
)

def forward(

src/anemoi/models/layers/processor.py (4 additions, 0 deletions)

@@ -95,6 +95,7 @@ def __init__(
cpu_offload: bool = False,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -113,6 +114,8 @@ def __init__(
ratio of mlp hidden dimension to embedding dimension, default 4
activation : str, optional
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.1
"""
super().__init__(
num_channels=num_channels,
@@ -133,6 +136,7 @@
num_layers=self.chunk_size,
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
)

self.offload_layers(cpu_offload)
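At the top level, TransformerProcessor now accepts `dropout_p` (defaulting to 0.1 in this file) and threads it through TransformerProcessorChunk into every block's self-attention. A hedged construction sketch, with keyword names taken from the test fixture below and illustrative values; any argument not shown in this diff may differ in the real class:

```python
# Illustrative only: enabling attention dropout from the processor level.
from anemoi.models.layers.processor import TransformerProcessor

processor = TransformerProcessor(
    num_layers=16,
    window_size=512,
    num_channels=128,   # divisible by num_heads
    num_chunks=2,
    activation="GELU",
    cpu_offload=False,
    num_heads=16,
    mlp_hidden_ratio=4,
    dropout_p=0.1,      # passed through chunk.py into MultiHeadSelfAttention
)
```

Setting `dropout_p=0.0` reproduces the pre-PR behaviour.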

tests/layers/block/test_block_transformer.py (10 additions, 3 deletions)

@@ -29,11 +29,14 @@ class TestTransformerProcessorBlock:
num_heads=st.integers(min_value=1, max_value=10),
activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
window_size=st.integers(min_value=1, max_value=512),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size):
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size)
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
)
assert isinstance(block, TransformerProcessorBlock)

assert isinstance(block.layer_norm1, nn.LayerNorm)
@@ -49,6 +52,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
window_size=st.integers(min_value=1, max_value=512),
shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
batch_size=st.integers(min_value=1, max_value=40),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_forward_output(
@@ -60,9 +64,12 @@
window_size,
shapes,
batch_size,
dropout_p,
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size)
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
)

x = torch.randn((batch_size, num_channels))

tests/layers/chunk/test_chunk_transformer.py (4 additions, 0 deletions)

@@ -20,6 +20,7 @@ def init(self):
mlp_hidden_ratio: int = 4
activation: str = "GELU"
window_size: int = 13
dropout_p: float = 0.1

# num_heads must be evenly divisible by num_channels for MHSA
return (
@@ -29,6 +30,7 @@
mlp_hidden_ratio,
activation,
window_size,
dropout_p,
)

@pytest.fixture
@@ -40,6 +42,7 @@ def processor_chunk(self, init):
mlp_hidden_ratio,
activation,
window_size,
dropout_p,
) = init
return TransformerProcessorChunk(
num_channels=num_channels,
@@ -48,6 +51,7 @@
mlp_hidden_ratio=mlp_hidden_ratio,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
)

def test_all_blocks(self, processor_chunk):

tests/layers/processor/test_transformer_processor.py (6 additions, 0 deletions)

@@ -21,6 +21,7 @@ def transformer_processor_init():
cpu_offload = False
num_heads = 16
mlp_hidden_ratio = 4
dropout_p = 0.1
return (
num_layers,
window_size,
@@ -30,6 +31,7 @@
cpu_offload,
num_heads,
mlp_hidden_ratio,
dropout_p,
)


@@ -44,6 +46,7 @@ def transformer_processor(transformer_processor_init):
cpu_offload,
num_heads,
mlp_hidden_ratio,
dropout_p,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -54,6 +57,7 @@
cpu_offload=cpu_offload,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
)


@@ -67,6 +71,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_cpu_offload,
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
assert transformer_processor.num_chunks == num_chunks
@@ -84,6 +89,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_cpu_offload,
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
) = transformer_processor_init
gridsize = 100
batch_size = 1

tests/layers/test_attention.py (10 additions, 6 deletions)

@@ -18,29 +18,32 @@
@given(
num_heads=st.integers(min_value=1, max_value=50),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p):
embed_dim = (
num_heads * embed_dim_multiplier
) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

assert isinstance(mhsa, nn.Module)
assert mhsa.num_heads == num_heads
assert mhsa.embed_dim == embed_dim
assert mhsa.head_dim == embed_dim // num_heads
assert dropout_p == mhsa.dropout_p


@pytest.mark.gpu
@given(
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(deadline=None)
def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

x = torch.randn(batch_size * 2, embed_dim)
shapes = [list(x.shape)]
@@ -54,10 +57,11 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
shapes = [list(x.shape)]