From 4d17e3324f5449ba2ce1f2e33002fbf12f5b0b5e Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:48:41 +0100
Subject: [PATCH 01/15] feat: implemented RMSLayerNorm

---
 src/modalities/models/components/__init__.py  |  0
 .../models/components/layer_norms.py          | 80 +++++++++++++++++++
 2 files changed, 80 insertions(+)
 create mode 100644 src/modalities/models/components/__init__.py
 create mode 100644 src/modalities/models/components/layer_norms.py

diff --git a/src/modalities/models/components/__init__.py b/src/modalities/models/components/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/modalities/models/components/layer_norms.py b/src/modalities/models/components/layer_norms.py
new file mode 100644
index 00000000..45bd57e4
--- /dev/null
+++ b/src/modalities/models/components/layer_norms.py
@@ -0,0 +1,80 @@
+from typing import Annotated
+
+import torch
+import torch.nn as nn
+from pydantic import BaseModel, Field
+from torch.nn import functional as F
+
+from modalities.config.lookup_enum import LookupEnum
+
+
+class LayerNormIF(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class RMSLayerNorm(LayerNormIF):
+    def __init__(self, ndim: int, epsilon: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+        Original paper: https://arxiv.org/pdf/1910.07467.pdf
+        Source code adapted from https://github.com/facebookresearch/llama/blob/a0a4da8b497c566403941ceec47c2512ecf9dd20/llama/model.py#L34C1-L77C36
+
+        Args:
+            ndim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.epsilon = epsilon
+        self.weight = nn.Parameter(torch.ones(ndim))
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class ZLayerNorm(LayerNormIF):
+    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
+
+    def __init__(self, ndim: int, bias: bool, epsilon: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+        self.epsilon = epsilon
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(
+            input=x,
+            normalized_shape=self.weight.shape,
+            weight=self.weight,
+            bias=self.bias,
+            eps=self.epsilon,
+        )
+
+
+class LayerNorms(LookupEnum):
+    """
+    An enumeration of the different layer normalization techniques.
+    """
+
+    RMSNorm = RMSLayerNorm
+    ZLayerNorm = ZLayerNorm
+
+
+class ZLayerNormConfig(BaseModel):
+    ndim: Annotated[int, Field(strict=True, ge=1)]
+    bias: Annotated[bool, Field(default=True)]
+    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
+
+
+class RMSLayerNormConfig(BaseModel):
+    ndim: Annotated[int, Field(strict=True, ge=1)]
+    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
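
A minimal usage sketch of the new module (illustrative only, not part of the patch; the import path and tensor sizes follow the file added above):

    import torch
    from modalities.models.components.layer_norms import RMSLayerNorm

    norm = RMSLayerNorm(ndim=768, epsilon=1e-6)
    x = torch.randn(4, 256, 768)   # (batch, sequence, embedding)
    y = norm(x)                    # scales each embedding vector by rsqrt(mean(x**2, dim=-1) + epsilon), then by weight
    assert y.shape == x.shape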

From 8dd74a4f98968be85c1600200cb6fd0aa5700553 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:49:29 +0100
Subject: [PATCH 02/15] feat: added configurable layer norms to GPT2Model

---
 src/modalities/models/gpt2/gpt2_model.py | 85 +++++++++++-------------
 1 file changed, 40 insertions(+), 45 deletions(-)

diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index b633d1b4..bd36f4d2 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -9,6 +9,8 @@
 from pydantic import BaseModel, Field, model_validator
 from torch.nn import functional as F
 
+from modalities.config.config import PydanticLayerNormIFType
+from modalities.models.components.layer_norms import LayerNormIF
 from modalities.models.model import NNModel
 
 # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -47,11 +49,13 @@ class GPT2LLMConfig(BaseModel):
     ffn_hidden: Annotated[int, Field(strict=True, ge=1)]
 
     dropout: Annotated[float, Field(strict=True, ge=0.0)]
-    bias: bool  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
-    attention: AttentionConfig
-    activation: ActivationType
-    epsilon: Annotated[float, Field(strict=True, ge=0.0)]
+    bias: bool  # True: bias in Linears like GPT-2. False: a bit better and faster
+    attention_config: AttentionConfig
+    activation_type: ActivationType
     weight_init: WeightInitailizationConfig
+    attention_norm: PydanticLayerNormIFType
+    ffn_norm: PydanticLayerNormIFType
+    lm_head_norm: PydanticLayerNormIFType
 
     @model_validator(mode="after")
     def validate_sizes(self) -> "GPT2LLMConfig":
@@ -64,35 +68,16 @@ def validate_sizes(self) -> "GPT2LLMConfig":
         return self
 
 
-class LayerNorm(nn.Module):
-    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
-
-    def __init__(self, ndim: int, bias: bool, epsilon: float):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(ndim))
-        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
-        self.epsilon = epsilon
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            input=input,
-            normalized_shape=self.weight.shape,
-            weight=self.weight,
-            bias=self.bias,
-            eps=self.epsilon,
-        )
-
-
 class CausalSelfAttention(nn.Module):
     def __init__(
-        self, n_head: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int
+        self, n_head: int, n_embd: int, attention_config: AttentionConfig, bias: bool, dropout: float, block_size: int
     ):
         super().__init__()
         assert n_embd % n_head == 0
         # key, query, value projections for all heads, but in a batch
         self.c_attn = nn.Linear(
             in_features=n_embd,
-            out_features=attention.scaling_factor * n_embd,
+            out_features=attention_config.scaling_factor * n_embd,
             bias=bias,
         )
 
@@ -109,7 +94,7 @@ def __init__(
         self.n_head = n_head
         self.n_embd = n_embd
         self.dropout = dropout
-        self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION
+        self.flash = attention_config.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION
 
         if not self.flash:
             # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -181,32 +166,39 @@ def __init__(
         self,
         n_embd: int,
         bias: bool,
-        epsilon: float,
-        activation: ActivationType,
+        activation_type: ActivationType,
         n_head: int,
-        attention: AttentionConfig,
+        attention_config: AttentionConfig,
         dropout: float,
         block_size: int,
         ffn_hidden: int,
+        attention_norm: LayerNormIF,
+        ffn_norm: LayerNormIF,
     ):
         super().__init__()
-        self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)
+        self.attention_norm = attention_norm
+        self.ffn_norm = ffn_norm
         self.attn = CausalSelfAttention(
-            n_head=n_head, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size
+            n_head=n_head,
+            n_embd=n_embd,
+            attention_config=attention_config,
+            bias=bias,
+            dropout=dropout,
+            block_size=block_size,
         )
-        self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)
-
-        if activation == ActivationType.GELU:
+        if activation_type == ActivationType.GELU:
             self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
-        elif activation == ActivationType.FUSED_SWIGLU:
+        elif activation_type == ActivationType.FUSED_SWIGLU:
             hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
             self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
         else:
-            raise Exception("unimplemented activation")
+            raise NotImplementedError("unimplemented activation")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.ln_1(x))
-        x = x + self.mlp(self.ln_2(x))
+        x = self.attention_norm(x)
+        x = x + self.attn(x)
+        x = self.ffn_norm(x)
+        x = x + self.mlp(x)
         return x
 
 
@@ -223,10 +215,12 @@ def __init__(
         ffn_hidden: int,
         dropout: float,
         bias: bool,
-        attention: AttentionConfig,
-        activation: ActivationType,
-        epsilon: float,
+        attention_config: AttentionConfig,
+        activation_type: ActivationType,
         weight_init: WeightInitailizationConfig,
+        attention_norm: LayerNormIF,
+        ffn_norm: LayerNormIF,
+        lm_head_norm: LayerNormIF,
     ):
         super().__init__()
         self.sample_key = sample_key
@@ -246,18 +240,19 @@ def __init__(
                         GPT2Block(
                             n_embd=n_embd,
                             bias=bias,
-                            epsilon=epsilon,
-                            activation=activation,
+                            activation_type=activation_type,
                             n_head=n_head,
-                            attention=attention,
+                            attention_config=attention_config,
                             dropout=dropout,
                             block_size=block_size,
                             ffn_hidden=ffn_hidden,
+                            attention_norm=attention_norm,
+                            ffn_norm=ffn_norm,
                         )
                         for _ in range(n_layer)
                     ]
                 ),
-                ln_f=LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon),
+                ln_f=lm_head_norm,
             )
         )
         self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size, bias=False)
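
For clarity, the data flow of the reworked GPT2Block.forward, as a sketch that mirrors the hunk above (note that the residual branches now start from the normalized activations):

    def block_forward(x, attention_norm, ffn_norm, attn, mlp):
        x = attention_norm(x)
        x = x + attn(x)
        x = ffn_norm(x)
        x = x + mlp(x)
        return x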

From 8bfec3ea17d8da9fe07001b2a1c7ac3cfe516bec Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:49:58 +0100
Subject: [PATCH 03/15] feat: wired up RMSNorm and ZLayerNorm in registry

---
 src/modalities/config/config.py       | 2 ++
 src/modalities/registry/components.py | 4 ++++
 2 files changed, 6 insertions(+)

diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 353cf14e..7238ab2a 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -21,6 +21,7 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.subscriber import MessageSubscriberIF
 from modalities.loss_functions import Loss
+from modalities.models.components.layer_norms import LayerNormIF
 from modalities.models.gpt2.collator import CollateFnIF
 from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support
 from modalities.util import get_date_of_run, parse_enum_by_name
@@ -61,6 +62,7 @@ def __get_pydantic_core_schema__(
 PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
 PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
 PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)]
+PydanticLayerNormIFType = Annotated[LayerNormIF, PydanticThirdPartyTypeIF(LayerNormIF)]
 
 
 class ProcessGroupBackendType(LookupEnum):
diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py
index 40dfccfd..3257e1fa 100644
--- a/src/modalities/registry/components.py
+++ b/src/modalities/registry/components.py
@@ -43,6 +43,7 @@
     ResultsSubscriberFactory,
 )
 from modalities.loss_functions import CLMCrossEntropyLoss
+from modalities.models.components.layer_norms import RMSLayerNorm, RMSLayerNormConfig, ZLayerNorm, ZLayerNormConfig
 from modalities.models.gpt2.collator import GPT2LLMCollateFn
 from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig
 from modalities.models.huggingface.huggingface_models import (
@@ -154,4 +155,7 @@ class ComponentEntity:
         ResultsSubscriberFactory.get_wandb_result_subscriber,
         WandBEvaluationResultSubscriberConfig,
     ),
+    # layer norms
+    ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig),
+    ComponentEntity("layer_norm", "z_layer_norm", ZLayerNorm, ZLayerNormConfig),
 ]

From fa63c9585d06b9898032fa05558dc42fee08f725 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:53:36 +0100
Subject: [PATCH 04/15] feat: created YAML config with ZLayerNorm and
 RMSLayerNorm support

---
 .../config_example_mem_map_dataset.yaml       | 31 ++++++++++++++++---
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml
index 62498856..74cfd26f 100644
--- a/config_files/config_example_mem_map_dataset.yaml
+++ b/config_files/config_example_mem_map_dataset.yaml
@@ -31,7 +31,7 @@ train_dataset:
   component_key: dataset
   variant_key: packed_mem_map_dataset_megatron
   config: 
-    raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_16777216.pbin
+    raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1050391.pbin
     block_size: ${settings.training.sequence_length}
     sample_key: ${settings.referencing_keys.sample_key}
 
@@ -142,15 +142,36 @@ model:
     ffn_hidden: 2048
     n_embd: 768
     dropout: 0.0
-    bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
-    attention:
+    bias: true # True: bias in Linears, like GPT-2. False: a bit better and faster
+    attention_config:
       attention_type: pytorch_flash_attention
       scaling_factor: 3
-    activation: gelu
-    epsilon: 1e-5
+    activation_type: gelu
     weight_init:
       mean: 0.0
       std: 0.02
+    attention_norm:
+      component_key: layer_norm
+      variant_key: z_layer_norm
+      config:
+        ndim: ${model.config.n_embd}
+        bias: true
+        eps: 1e-5
+    ffn_norm:
+      component_key: layer_norm
+      variant_key: z_layer_norm
+      config:
+        ndim: ${model.config.n_embd}
+        bias: true
+        eps: 1e-5
+    lm_head_norm:
+      component_key: layer_norm
+      variant_key: z_layer_norm
+      config:
+        ndim: ${model.config.n_embd}
+        bias: true
+        eps: 1e-5
+
 
 wrapped_model:
   component_key: model

From 04b4f4534b20901bf80ee6d9bab3a6b89eb62f16 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Thu, 7 Mar 2024 16:52:09 +0100
Subject: [PATCH 05/15] refactor: added workaround to instantiate multiple
 LayerNorm layers from a single LayerNorm instance

---
 .../models/components/layer_norms.py           | 18 +++++++++++++++---
 src/modalities/models/gpt2/gpt2_model.py       |  5 +++--
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/modalities/models/components/layer_norms.py b/src/modalities/models/components/layer_norms.py
index 45bd57e4..85f3647d 100644
--- a/src/modalities/models/components/layer_norms.py
+++ b/src/modalities/models/components/layer_norms.py
@@ -12,6 +12,9 @@ class LayerNormIF(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
 
+    def __copy__(self):
+        raise NotImplementedError
+
 
 class RMSLayerNorm(LayerNormIF):
     def __init__(self, ndim: int, epsilon: float = 1e-6):
@@ -40,6 +43,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight
 
+    def __copy__(self):
+        copied_instance = RMSLayerNorm(self.weight.shape[0], self.epsilon)
+        return copied_instance
+
 
 class ZLayerNorm(LayerNormIF):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
@@ -47,17 +54,22 @@ class ZLayerNorm(LayerNormIF):
     def __init__(self, ndim: int, bias: bool, epsilon: float = 1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(ndim))
-        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+        self.bias_tensor = nn.Parameter(torch.zeros(ndim)) if bias else None
         self.epsilon = epsilon
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
+        normalized_x = F.layer_norm(
             input=x,
             normalized_shape=self.weight.shape,
             weight=self.weight,
-            bias=self.bias,
+            bias=self.bias_tensor,
             eps=self.epsilon,
         )
+        return normalized_x
+
+    def __copy__(self):
+        copied_instance = ZLayerNorm(self.weight.shape[0], self.bias_tensor is not None, self.epsilon)
+        return copied_instance
 
 
 class LayerNorms(LookupEnum):
diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index bd36f4d2..8e1d5ec3 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -1,4 +1,5 @@
 import math
+from copy import copy
 from enum import Enum
 from functools import partial
 from typing import Annotated, Dict
@@ -246,8 +247,8 @@ def __init__(
                             dropout=dropout,
                             block_size=block_size,
                             ffn_hidden=ffn_hidden,
-                            attention_norm=attention_norm,
-                            ffn_norm=ffn_norm,
+                            attention_norm=copy(attention_norm),
+                            ffn_norm=copy(ffn_norm),
                         )
                         for _ in range(n_layer)
                     ]
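
Why the copies matter, as a sketch (dimensions are illustrative): without them, all n_layer blocks would share the parameters of the single norm instance built from the config, whereas copy() now dispatches to the __copy__ overrides added above and yields a fresh, independently initialized instance per block.

    from copy import copy

    from modalities.models.components.layer_norms import RMSLayerNorm

    template = RMSLayerNorm(ndim=768, epsilon=1e-6)
    norms = [copy(template) for _ in range(12)]      # one independent norm per block
    assert norms[0].weight is not norms[1].weight    # parameters are not shared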

From 78a24618161a0f5acb71923b4cbbfb26c788ac62 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Thu, 7 Mar 2024 17:58:28 +0100
Subject: [PATCH 06/15] test: added layer norm tests

---
 tests/models/__init__.py                    |  0
 tests/models/components/__init__.py         |  0
 tests/models/components/test_layer_norms.py | 47 +++++++++++++++++++++
 3 files changed, 47 insertions(+)
 create mode 100644 tests/models/__init__.py
 create mode 100644 tests/models/components/__init__.py
 create mode 100644 tests/models/components/test_layer_norms.py

diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/models/components/__init__.py b/tests/models/components/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/models/components/test_layer_norms.py b/tests/models/components/test_layer_norms.py
new file mode 100644
index 00000000..8e118a3b
--- /dev/null
+++ b/tests/models/components/test_layer_norms.py
@@ -0,0 +1,47 @@
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from modalities.models.components.layer_norms import RMSLayerNorm, ZLayerNorm
+
+
+@pytest.fixture
+def rms_layer_norm() -> RMSLayerNorm:
+    norm = RMSLayerNorm(ndim=3, epsilon=1e-6)
+    weight_tensor = torch.Tensor([1, 2, 3])
+    norm.weight = nn.Parameter(weight_tensor)
+    return norm
+
+
+@pytest.fixture
+def z_layer_norm() -> ZLayerNorm:
+    norm = ZLayerNorm(ndim=3, bias=True, epsilon=1e-6)
+    return norm
+
+
+def test_rms_layer_norm_forward(rms_layer_norm):
+    x = torch.Tensor([0.1, 0.2, 0.3])
+    output = rms_layer_norm(x)
+    ref_x = x / np.sqrt((0.1**2 + 0.2**2 + 0.3**2) / 3 + 1e-6)
+    ref_tensor = ref_x * rms_layer_norm.weight
+
+    assert output.shape == x.shape
+    assert all(output == ref_tensor)
+
+
+def test_z_layer_norm_forward(z_layer_norm):
+    x = torch.Tensor([0.1, 0.2, 0.3])
+    output = z_layer_norm(x)
+    ndim = z_layer_norm.weight.shape[0]
+    ref_tensor = F.layer_norm(
+        input=x,
+        normalized_shape=z_layer_norm.weight.shape,
+        weight=nn.Parameter(torch.ones(ndim)),
+        bias=nn.Parameter(torch.zeros(ndim)) if z_layer_norm.bias_tensor is not None else None,
+        eps=z_layer_norm.epsilon,
+    )
+
+    assert output.shape == x.shape
+    assert all(output == ref_tensor)

From eebd7574ee04153b98fa36bc9cb73c19fa271549 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 16:43:32 +0100
Subject: [PATCH 07/15] refactor: type annotations for layer norms are now
 nn.Module

---
 src/modalities/models/gpt2/gpt2_model.py | 25 ++++++++++++------------
 src/modalities/registry/components.py    |  5 +++--
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index 8e1d5ec3..acb144ba 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -1,5 +1,5 @@
 import math
-from copy import copy
+from copy import deepcopy
 from enum import Enum
 from functools import partial
 from typing import Annotated, Dict
@@ -10,8 +10,7 @@
 from pydantic import BaseModel, Field, model_validator
 from torch.nn import functional as F
 
-from modalities.config.config import PydanticLayerNormIFType
-from modalities.models.components.layer_norms import LayerNormIF
+from modalities.config.config import PydanticPytorchModuleType
 from modalities.models.model import NNModel
 
 # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -54,9 +53,9 @@ class GPT2LLMConfig(BaseModel):
     attention_config: AttentionConfig
     activation_type: ActivationType
     weight_init: WeightInitailizationConfig
-    attention_norm: PydanticLayerNormIFType
-    ffn_norm: PydanticLayerNormIFType
-    lm_head_norm: PydanticLayerNormIFType
+    attention_norm: PydanticPytorchModuleType
+    ffn_norm: PydanticPytorchModuleType
+    lm_head_norm: PydanticPytorchModuleType
 
     @model_validator(mode="after")
     def validate_sizes(self) -> "GPT2LLMConfig":
@@ -173,8 +172,8 @@ def __init__(
         dropout: float,
         block_size: int,
         ffn_hidden: int,
-        attention_norm: LayerNormIF,
-        ffn_norm: LayerNormIF,
+        attention_norm: nn.Module,
+        ffn_norm: nn.Module,
     ):
         super().__init__()
         self.attention_norm = attention_norm
@@ -219,9 +218,9 @@ def __init__(
         attention_config: AttentionConfig,
         activation_type: ActivationType,
         weight_init: WeightInitailizationConfig,
-        attention_norm: LayerNormIF,
-        ffn_norm: LayerNormIF,
-        lm_head_norm: LayerNormIF,
+        attention_norm: nn.Module,
+        ffn_norm: nn.Module,
+        lm_head_norm: nn.Module,
     ):
         super().__init__()
         self.sample_key = sample_key
@@ -247,8 +246,8 @@ def __init__(
                             dropout=dropout,
                             block_size=block_size,
                             ffn_hidden=ffn_hidden,
-                            attention_norm=copy(attention_norm),
-                            ffn_norm=copy(ffn_norm),
+                            attention_norm=deepcopy(attention_norm),
+                            ffn_norm=deepcopy(ffn_norm),
                         )
                         for _ in range(n_layer)
                     ]
diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py
index 3257e1fa..ef31cb96 100644
--- a/src/modalities/registry/components.py
+++ b/src/modalities/registry/components.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import Type
 
+import torch.nn as nn
 from pydantic import BaseModel
 from torch.utils.data import BatchSampler, DistributedSampler
 from transformers import GPT2TokenizerFast
@@ -43,7 +44,7 @@
     ResultsSubscriberFactory,
 )
 from modalities.loss_functions import CLMCrossEntropyLoss
-from modalities.models.components.layer_norms import RMSLayerNorm, RMSLayerNormConfig, ZLayerNorm, ZLayerNormConfig
+from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
 from modalities.models.gpt2.collator import GPT2LLMCollateFn
 from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig
 from modalities.models.huggingface.huggingface_models import (
@@ -157,5 +158,5 @@ class ComponentEntity:
     ),
     # layer norms
     ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig),
-    ComponentEntity("layer_norm", "z_layer_norm", ZLayerNorm, ZLayerNormConfig),
+    ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig),
 ]

From 6382d0ebd5fb6046727089735c3235000c595dc2 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 16:45:08 +0100
Subject: [PATCH 08/15] refactor: removed ZLayerNorm in favor of the original
 PyTorch LayerNorm implementation, which needs no wrapper. Removed the
 __copy__ overrides, since deepcopy already copies an nn.Module recursively.
 Introduced an optional bias in RMSLayerNorm

---
 .../models/components/layer_norms.py          | 70 +++++--------------
 1 file changed, 19 insertions(+), 51 deletions(-)

diff --git a/src/modalities/models/components/layer_norms.py b/src/modalities/models/components/layer_norms.py
index 85f3647d..3d6a2bb6 100644
--- a/src/modalities/models/components/layer_norms.py
+++ b/src/modalities/models/components/layer_norms.py
@@ -3,21 +3,12 @@
 import torch
 import torch.nn as nn
 from pydantic import BaseModel, Field
-from torch.nn import functional as F
 
 from modalities.config.lookup_enum import LookupEnum
 
 
-class LayerNormIF(nn.Module):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
-
-    def __copy__(self):
-        raise NotImplementedError
-
-
-class RMSLayerNorm(LayerNormIF):
-    def __init__(self, ndim: int, epsilon: float = 1e-6):
+class RMSLayerNorm(nn.Module):
+    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
         Original paper: https://arxiv.org/pdf/1910.07467.pdf
@@ -25,51 +16,26 @@ def __init__(self, ndim: int, epsilon: float = 1e-6):
 
         Args:
             ndim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
-
+            epsilon (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+            bias (bool, optional): If True, the layer will learn an additive bias. Default is True.
         """
         super().__init__()
         self.epsilon = epsilon
         self.weight = nn.Parameter(torch.ones(ndim))
+        if bias:
+            self.bias_tensor = nn.Parameter(torch.zeros(ndim))
+        else:
+            self.bias_tensor = None
 
     def _norm(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-    def __copy__(self):
-        copied_instance = RMSLayerNorm(self.weight.shape[0], self.epsilon)
-        return copied_instance
-
-
-class ZLayerNorm(LayerNormIF):
-    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
-
-    def __init__(self, ndim: int, bias: bool, epsilon: float = 1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(ndim))
-        self.bias_tensor = nn.Parameter(torch.zeros(ndim)) if bias else None
-        self.epsilon = epsilon
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        normalized_x = F.layer_norm(
-            input=x,
-            normalized_shape=self.weight.shape,
-            weight=self.weight,
-            bias=self.bias_tensor,
-            eps=self.epsilon,
-        )
-        return normalized_x
-
-    def __copy__(self):
-        copied_instance = ZLayerNorm(self.weight.shape[0], self.bias_tensor is not None, self.epsilon)
-        return copied_instance
+        if self.bias_tensor is None:
+            return output * self.weight
+        else:
+            return output * self.weight + self.bias_tensor
 
 
 class LayerNorms(LookupEnum):
@@ -78,15 +44,17 @@ class LayerNorms(LookupEnum):
     """
 
     RMSNorm = RMSLayerNorm
-    ZLayerNorm = ZLayerNorm
+    LayerNorm = nn.LayerNorm
 
 
-class ZLayerNormConfig(BaseModel):
-    ndim: Annotated[int, Field(strict=True, ge=1)]
-    bias: Annotated[bool, Field(default=True)]
-    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
+class LayerNormConfig(BaseModel):
+    normalized_shape: Annotated[int, Field(strict=True, ge=1)]
+    eps: Annotated[float, Field(strict=True, gt=0, default=1e-6)]
+    elementwise_affine: Annotated[bool, Field(strict=True, default=True)]
+    bias: Annotated[bool, Field(strict=True, default=True)]
 
 
 class RMSLayerNormConfig(BaseModel):
     ndim: Annotated[int, Field(strict=True, ge=1)]
     epsilon: Annotated[float, Field(gt=0, default=1e-6)]
+    bias: Annotated[bool, Field(strict=True, default=True)]
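
A small sanity sketch of the resulting normalization (illustrative only; at initialization weight is all ones and the optional bias_tensor is all zeros, so the output equals the plain RMS normalization):

    import torch
    from modalities.models.components.layer_norms import RMSLayerNorm

    norm = RMSLayerNorm(ndim=3, bias=True, epsilon=1e-6)
    x = torch.tensor([0.1, 0.2, 0.3])
    expected = x / torch.sqrt((x ** 2).mean() + 1e-6)   # y = x / sqrt(mean(x**2) + epsilon) * weight + bias_tensor
    assert torch.allclose(norm(x), expected)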

From 1698a7d9fb8500c6cbb1f8689f0a548aa23a0ae1 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 16:45:47 +0100
Subject: [PATCH 09/15] refactor: renamed PydanticModelIFType to
 PydanticPytorchModuleType

---
 src/modalities/config/config.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 7238ab2a..55d4b9af 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -21,7 +21,6 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.subscriber import MessageSubscriberIF
 from modalities.loss_functions import Loss
-from modalities.models.components.layer_norms import LayerNormIF
 from modalities.models.gpt2.collator import CollateFnIF
 from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support
 from modalities.util import get_date_of_run, parse_enum_by_name
@@ -53,7 +52,7 @@ def __get_pydantic_core_schema__(
 PydanticCheckpointingExecutionIFType = Annotated[
     CheckpointingExecutionIF, PydanticThirdPartyTypeIF(CheckpointingExecutionIF)
 ]
-PydanticModelIFType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
+PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
 PydanticTokenizerIFType = Annotated[PreTrainedTokenizerFast, PydanticThirdPartyTypeIF(PreTrainedTokenizerFast)]
 PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)]
 PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)]
@@ -62,7 +61,6 @@ def __get_pydantic_core_schema__(
 PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
 PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
 PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)]
-PydanticLayerNormIFType = Annotated[LayerNormIF, PydanticThirdPartyTypeIF(LayerNormIF)]
 
 
 class ProcessGroupBackendType(LookupEnum):
@@ -136,24 +134,24 @@ class CheckpointingConfig(BaseModel):
 
 class AdamWOptimizerConfig(BaseModel):
     lr: float
-    wrapped_model: PydanticModelIFType
+    wrapped_model: PydanticPytorchModuleType
 
 
 class CheckpointedOptimizerConfig(BaseModel):
     checkpointing: PydanticCheckpointingIFType
     checkpoint_path: Path
-    wrapped_model: PydanticModelIFType
+    wrapped_model: PydanticPytorchModuleType
     optimizer: PydanticOptimizerIFType
 
 
 class CheckpointedModelConfig(BaseModel):
     checkpointing: PydanticCheckpointingIFType
     checkpoint_path: Path
-    model: PydanticModelIFType
+    model: PydanticPytorchModuleType
 
 
 class FSDPWrappedModelConfig(BaseModel):
-    model: PydanticModelIFType
+    model: PydanticPytorchModuleType
     sync_module_states: bool
     mixed_precision_settings: MixedPrecisionSettings
     sharding_strategy: ShardingStrategy
@@ -305,7 +303,7 @@ class Paths(BaseModel):
 
 
 class ComponentsModel(BaseModel):
-    wrapped_model: PydanticModelIFType
+    wrapped_model: PydanticPytorchModuleType
     optimizer: PydanticOptimizerIFType
     loss_fn: PydanticLossIFType
     train_dataloader: PydanticLLMDataLoaderIFType
@@ -317,7 +315,7 @@ class ComponentsModel(BaseModel):
 
 
 class ComponentsInferenceModel(BaseModel):
-    wrapped_model: PydanticModelIFType
+    wrapped_model: PydanticPytorchModuleType
     cuda_env: CudaEnv
 
 

From 5d05fb398e322efab18f4162504d25930dc7909c Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 16:46:19 +0100
Subject: [PATCH 10/15] refactor: updated example config with latest layer norm changes

---
 .../config_example_mem_map_dataset.yaml       | 37 +++++++++++++++----
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml
index 74cfd26f..02a4425f 100644
--- a/config_files/config_example_mem_map_dataset.yaml
+++ b/config_files/config_example_mem_map_dataset.yaml
@@ -150,28 +150,51 @@ model:
     weight_init:
       mean: 0.0
       std: 0.02
+    # attention_norm:
+    #   component_key: layer_norm
+    #   variant_key: layer_norm
+    #   config:
+    #     normalized_shape: ${model.config.n_embd}
+    #     elementwise_affine: true
+    #     bias: true
+    #     eps: 1e-5
+    # ffn_norm:
+    #   component_key: layer_norm
+    #   variant_key: layer_norm
+    #   config:
+    #     normalized_shape: ${model.config.n_embd}
+    #     elementwise_affine: true
+    #     bias: true
+    #     eps: 1e-5
+    # lm_head_norm:
+    #   component_key: layer_norm
+    #   variant_key: layer_norm
+    #   config:
+    #     normalized_shape: ${model.config.n_embd}
+    #     elementwise_affine: true
+    #     bias: true
+    #     eps: 1e-5
     attention_norm:
       component_key: layer_norm
-      variant_key: z_layer_norm
+      variant_key: rms_norm
       config:
         ndim: ${model.config.n_embd}
         bias: true
-        eps: 1e-5
+        epsilon: 1e-5
     ffn_norm:
       component_key: layer_norm
-      variant_key: z_layer_norm
+      variant_key: rms_norm
       config:
         ndim: ${model.config.n_embd}
         bias: true
-        eps: 1e-5
+        epsilon: 1e-5
     lm_head_norm:
       component_key: layer_norm
-      variant_key: z_layer_norm
+      variant_key: rms_norm
       config:
         ndim: ${model.config.n_embd}
         bias: true
-        eps: 1e-5
-
+        epsilon: 1e-5
 
 wrapped_model:
   component_key: model

From 56b88aef06c004b1b272109edf130dd7b928a185 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 17:12:52 +0100
Subject: [PATCH 11/15] test: removed ZLayerNorm tests and improved RMSLayerNorm
 tests

---
 tests/models/components/test_layer_norms.py | 30 +++------------------
 1 file changed, 4 insertions(+), 26 deletions(-)

diff --git a/tests/models/components/test_layer_norms.py b/tests/models/components/test_layer_norms.py
index 8e118a3b..5bb862eb 100644
--- a/tests/models/components/test_layer_norms.py
+++ b/tests/models/components/test_layer_norms.py
@@ -2,22 +2,16 @@
 import pytest
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 
-from modalities.models.components.layer_norms import RMSLayerNorm, ZLayerNorm
+from modalities.models.components.layer_norms import RMSLayerNorm
 
 
 @pytest.fixture
 def rms_layer_norm() -> RMSLayerNorm:
     norm = RMSLayerNorm(ndim=3, epsilon=1e-6)
     weight_tensor = torch.Tensor([1, 2, 3])
-    norm.weight = nn.Parameter(weight_tensor)
-    return norm
-
-
-@pytest.fixture
-def z_layer_norm() -> ZLayerNorm:
-    norm = ZLayerNorm(ndim=3, bias=True, epsilon=1e-6)
+    norm.gain = nn.Parameter(weight_tensor)
+    norm.bias_tensor = nn.Parameter(torch.ones(3))
     return norm
 
 
@@ -25,23 +19,7 @@ def test_rms_layer_norm_forward(rms_layer_norm):
     x = torch.Tensor([0.1, 0.2, 0.3])
     output = rms_layer_norm(x)
     ref_x = x / np.sqrt((0.1**2 + 0.2**2 + 0.3**2) / 3 + 1e-6)
-    ref_tensor = ref_x * rms_layer_norm.weight
-
-    assert output.shape == x.shape
-    assert all(output == ref_tensor)
-
-
-def test_z_layer_norm_forward(z_layer_norm):
-    x = torch.Tensor([0.1, 0.2, 0.3])
-    output = z_layer_norm(x)
-    ndim = z_layer_norm.weight.shape[0]
-    ref_tensor = F.layer_norm(
-        input=x,
-        normalized_shape=z_layer_norm.weight.shape,
-        weight=nn.Parameter(torch.ones(ndim)),
-        bias=nn.Parameter(torch.zeros(ndim)) if z_layer_norm.bias_tensor is not None else None,
-        eps=z_layer_norm.epsilon,
-    )
+    ref_tensor = ref_x * rms_layer_norm.gain + torch.tensor([1, 1, 1])
 
     assert output.shape == x.shape
     assert all(output == ref_tensor)

From 1558ef4f96916e1e42fd2e3487a4d3dea29e62e1 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 17:13:19 +0100
Subject: [PATCH 12/15] refactor: renamed RMSLayerNorm.weight to
 RMSLayerNorm.gain

---
 src/modalities/models/components/layer_norms.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/modalities/models/components/layer_norms.py b/src/modalities/models/components/layer_norms.py
index 3d6a2bb6..5bb8660b 100644
--- a/src/modalities/models/components/layer_norms.py
+++ b/src/modalities/models/components/layer_norms.py
@@ -21,7 +21,7 @@ def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-6):
         """
         super().__init__()
         self.epsilon = epsilon
-        self.weight = nn.Parameter(torch.ones(ndim))
+        self.gain = nn.Parameter(torch.ones(ndim))
         if bias:
             self.bias_tensor = nn.Parameter(torch.zeros(ndim))
         else:
@@ -33,9 +33,9 @@ def _norm(self, x: torch.Tensor) -> torch.Tensor:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self._norm(x.float()).type_as(x)
         if self.bias_tensor is None:
-            return output * self.weight
+            return output * self.gain
         else:
-            return output * self.weight + self.bias_tensor
+            return output * self.gain + self.bias_tensor
 
 
 class LayerNorms(LookupEnum):

From 3469d7c10c961f8ed2b1a37af40353a5f5803d52 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 17:15:20 +0100
Subject: [PATCH 13/15] refactor: changed default value for epsilon to 1e-5

---
 src/modalities/models/components/layer_norms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/modalities/models/components/layer_norms.py b/src/modalities/models/components/layer_norms.py
index 5bb8660b..1c63f645 100644
--- a/src/modalities/models/components/layer_norms.py
+++ b/src/modalities/models/components/layer_norms.py
@@ -8,7 +8,7 @@
 
 
 class RMSLayerNorm(nn.Module):
-    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-6):
+    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
         """
         Initialize the RMSNorm normalization layer.
         Original paper: https://arxiv.org/pdf/1910.07467.pdf

From ba9cba6cae3c289e5a80766d9177c3cb733ae81d Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 18:52:18 +0100
Subject: [PATCH 14/15] fix: completed renaming of attention -> attention_config

---
 src/modalities/models/gpt2/gpt2_model.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index b98137b9..ff92a8c8 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -158,13 +158,11 @@ class GPT2LLMConfig(BaseModel):
     bias: bool  # True: bias in Linears like GPT-2. False: a bit better and faster
     attention_config: AttentionConfig
     activation_type: ActivationType
-    weight_init: WeightInitailizationConfig
     attention_norm: PydanticPytorchModuleType
     ffn_norm: PydanticPytorchModuleType
     lm_head_norm: PydanticPytorchModuleType
     weight_init: WeightInitializationConfig
 
-
     @model_validator(mode="after")
     def validate_sizes(self) -> "GPT2LLMConfig":
         for param, param_name in zip(
@@ -185,7 +183,6 @@ def __init__(
         bias: bool,
         dropout: float,
         block_size: int,
-
     ):
         super().__init__()
         assert n_embd % n_head == 0
@@ -215,7 +212,7 @@ def __init__(
         # TODO: inject QKVTransforms from outside
         self.qkv_transforms = nn.ModuleList(
             transform_config.type_hint.value(**convert_base_model_config_to_dict(transform_config.config))
-            for transform_config in attention.qkv_transforms
+            for transform_config in attention_config.qkv_transforms
         )
 
         if not self.flash:
@@ -268,7 +265,7 @@ def __init__(self, n_embd: int, ffn_hidden: int, bias: bool, dropout: float):
         super().__init__()
         self.c_fc = nn.Linear(
             in_features=n_embd,
-            out_features=ffn_hidden,  # 4 * n_embd,
+            out_features=ffn_hidden,  # best practice: 4 * n_embd,
             bias=bias,
         )
         self.gelu = nn.GELU()
@@ -370,7 +367,7 @@ def __init__(
             raise TypeError(f"{poe_type} not supported")
 
         if poe_type is not PositionTypes.NOPE and RotaryTransform in [
-            config.type_hint.value for config in attention.qkv_transforms
+            config.type_hint.value for config in attention_config.qkv_transforms
         ]:
             raise ValueError('It is expected to use "RotaryTransform" together with "NOPE".')
 

From 3600d99a3a704d44bdc7fe169db945b82ea9de61 Mon Sep 17 00:00:00 2001
From: Max Luebbering <le1nux@users.noreply.github.com>
Date: Wed, 13 Mar 2024 18:53:05 +0100
Subject: [PATCH 15/15] feat: added RoPE to the GPT2 model config

---
 .../config_example_mem_map_dataset.yaml       | 31 +++++--------------
 1 file changed, 7 insertions(+), 24 deletions(-)

diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml
index 02a4425f..1175a731 100644
--- a/config_files/config_example_mem_map_dataset.yaml
+++ b/config_files/config_example_mem_map_dataset.yaml
@@ -136,6 +136,7 @@ model:
     sample_key: ${settings.referencing_keys.sample_key}
     prediction_key: ${settings.referencing_keys.prediction_key}
     block_size: ${settings.training.sequence_length}
+    poe_type: NOPE
     vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
     n_layer: 12
     n_head: 12
@@ -146,34 +147,16 @@ model:
     attention_config:
       attention_type: pytorch_flash_attention
       scaling_factor: 3
+      qkv_transforms:
+        - type_hint: RotaryTransform
+          config:
+            n_embd: ${model.config.n_embd}
+            n_head: ${model.config.n_head}
+            seq_length_dim: -2
     activation_type: gelu
     weight_init:
       mean: 0.0
       std: 0.02
-    # attention_norm:
-    #   component_key: layer_norm
-    #   variant_key: layer_norm
-    #   config:
-    #     normalized_shape: ${model.config.n_embd}
-    #     elementwise_affine: true
-    #     bias: true
-    #     eps: 1e-5
-    # ffn_norm:
-    #   component_key: layer_norm
-    #   variant_key: layer_norm
-    #   config:
-    #     normalized_shape: ${model.config.n_embd}
-    #     elementwise_affine: true
-    #     bias: true
-    #     eps: 1e-5
-    # lm_head_norm:
-    #   component_key: layer_norm
-    #   variant_key: layer_norm
-    #   config:
-    #     normalized_shape: ${model.config.n_embd}
-    #     elementwise_affine: true
-    #     bias: true
-    #     eps: 1e-5
     attention_norm:
       component_key: layer_norm
       variant_key: rms_norm