Merge pull request #67 from Modalities/rms_norm
RMS norm implementation
Showing 9 changed files with 171 additions and 60 deletions.
@@ -0,0 +1,60 @@
from typing import Annotated

import torch
import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.config.lookup_enum import LookupEnum


class RMSLayerNorm(nn.Module):
    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
        """
        Initialize the RMSNorm normalization layer.
        Original paper: https://arxiv.org/pdf/1910.07467.pdf
        Source code adapted from https://github.com/facebookresearch/llama/blob/a0a4da8b497c566403941ceec47c2512ecf9dd20/llama/model.py#L34C1-L77C36

        Args:
            ndim (int): The dimension of the input tensor.
            bias (bool, optional): If True, the layer will learn an additive bias. Default is True.
            epsilon (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
        """
        super().__init__()
        self.epsilon = epsilon
        self.gain = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias_tensor = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias_tensor = None

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector by the reciprocal of its root mean square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        if self.bias_tensor is None:
            return output * self.gain
        else:
            return output * self.gain + self.bias_tensor


class LayerNorms(LookupEnum):
    """
    An enumeration of the different layer normalization techniques.
    """

    RMSNorm = RMSLayerNorm
    LayerNorm = nn.LayerNorm


class LayerNormConfig(BaseModel):
    normalized_shape: Annotated[int, Field(strict=True, ge=1)]
    eps: Annotated[float, Field(strict=True, gt=0, default=1e-6)]
    elementwise_affine: Annotated[bool, Field(strict=True, default=True)]
    bias: Annotated[bool, Field(strict=True, default=True)]


class RMSLayerNormConfig(BaseModel):
    ndim: Annotated[int, Field(strict=True, ge=1)]
    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
    bias: Annotated[bool, Field(strict=True, default=True)]
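
For context, a minimal usage sketch of the new layer (not part of this diff; the import path below is assumed for illustration and depends on where this file lives in the repository):

import torch
# Assumed import path, for illustration only.
from modalities.models.components.layer_norms import RMSLayerNorm

norm = RMSLayerNorm(ndim=512, bias=True, epsilon=1e-5)
x = torch.randn(2, 16, 512)  # (batch, sequence, hidden) activations
y = norm(x)                  # RMS-normalized over the last dimension, scaled by the learned gain and shifted by the bias
assert y.shape == x.shape

The LayerNorms lookup enum allows a config file to select between RMSNorm and nn.LayerNorm by name, with the corresponding pydantic config classes validating the constructor arguments before the layer is built.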