From 91533c6606ac5114175a2680c18d805f391289ee Mon Sep 17 00:00:00 2001
From: theissenhelen
Date: Tue, 1 Oct 2024 08:05:47 +0000
Subject: [PATCH] feat: make alibi_slopes configurable in block, chunk, processor

---
 src/anemoi/models/layers/block.py     | 4 +++-
 src/anemoi/models/layers/chunk.py     | 6 ++++--
 src/anemoi/models/layers/processor.py | 4 +++-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index aec3a3f..6ce72f3 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -64,7 +64,8 @@ def __init__(
         window_size: int,
         dropout_p: float = 0.0,
         use_flash_attention: bool = False,
-        softcap: float = 0.0,
+        softcap: float | None = 0.0,
+        alibi_slopes: Tensor | None = None,
     ):
         super().__init__()
 
@@ -85,6 +86,7 @@ def __init__(
             dropout_p=dropout_p,
             use_flash_attention=use_flash_attention,
             softcap=softcap,
+            alibi_slopes=alibi_slopes,
         )
 
         self.mlp = nn.Sequential(
diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py
index ab499ac..c4d4278 100644
--- a/src/anemoi/models/layers/chunk.py
+++ b/src/anemoi/models/layers/chunk.py
@@ -75,7 +75,8 @@ def __init__(
         activation: str = "GELU",
         dropout_p: float = 0.0,
         use_flash_attention: bool = False,
-        softcap: float = 0.0,
+        softcap: float | None = 0.0,
+        alibi_slopes: Tensor | None = None,
     ) -> None:
         """Initialize TransformerProcessor.
 
@@ -104,8 +105,9 @@ def __init__(
             activation=activation,
             window_size=window_size,
             dropout_p=dropout_p,
-            softcap=softcap,
             use_flash_attention=use_flash_attention,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
         )
 
     def forward(
diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py
index 3ef7ddb..2f8ca40 100644
--- a/src/anemoi/models/layers/processor.py
+++ b/src/anemoi/models/layers/processor.py
@@ -97,7 +97,8 @@ def __init__(
         mlp_hidden_ratio: int = 4,
         dropout_p: float = 0.1,
         use_flash_attention: bool = False,
-        softcap: float = 0.0,
+        softcap: float | None = 0.0,
+        alibi_slopes: Tensor | None = None,
         **kwargs,
     ) -> None:
         """Initialize TransformerProcessor.
 
@@ -141,6 +142,7 @@ def __init__(
             dropout_p=dropout_p,
             use_flash_attention=use_flash_attention,
             softcap=softcap,
+            alibi_slopes=alibi_slopes,
         )
 
         self.offload_layers(cpu_offload)
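
Note (illustration only, not part of the patch): alibi_slopes is the per-head slope
tensor used for ALiBi-style attention biasing, and flash-attention backends typically
expect it as a float32 tensor of shape (num_heads,). The sketch below computes the
standard geometric-sequence slopes from the ALiBi paper for a power-of-two head count
and passes them through the new keyword argument; the helper name get_alibi_slopes and
the head count of 16 are assumptions for this example, not code from this repository.

import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Standard ALiBi slopes 2**(-8*i/num_heads), i = 1..num_heads (power-of-two head counts)."""
    ratio = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([ratio**i for i in range(1, num_heads + 1)], dtype=torch.float32)

slopes = get_alibi_slopes(num_heads=16)  # shape: (num_heads,)
# The tensor can then be forwarded through the constructors touched above,
# e.g. TransformerProcessor(..., alibi_slopes=slopes), with the remaining
# constructor arguments elided here.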