Merge pull request #39 from Modalities/rope_embeddings
Feat: RoPE Embeddings
mali-git authored Mar 12, 2024
2 parents 0807555 + c6f349c commit 7054b5b
Showing 8 changed files with 236 additions and 17 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -151,4 +151,6 @@ tags
checkpoints
data
docs/source/generated
docs/source/api
docs/source/api
pyenv*
.devcontainer/*
1 change: 0 additions & 1 deletion config_files/config.yaml
@@ -112,7 +112,6 @@ training:
train_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size}
global_num_seen_samples: ${modalities_setup.settings.global_num_seen_samples}
do_apply_activation_checkpointing: True

checkpointing:
checkpointing_strategy:
type_hint: SaveKMostRecentCheckpointsStrategy
15 changes: 11 additions & 4 deletions config_files/config_lorem_ipsum.yaml
@@ -186,9 +186,10 @@ model:
component_key: model
variant_key: gpt2
config:
sample_key: "input_ids" # TODO reference this
prediction_key: "logits" # TODO reference this
block_size: 256 # TODO reference this (same as sequence length)
sample_key: ${settings.referencing_keys.sample_key}
poe_type: "NOPE"
block_size: ${settings.training.sequence_length}
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head: 4
@@ -197,8 +198,14 @@ model:
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
attention_type: default_attention # pytorch_flash_attention
attention_type: pytorch_flash_attention
scaling_factor: 3
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head}
seq_length_dim: -2
activation: gelu
epsilon: 1e-5
weight_init:
8 changes: 8 additions & 0 deletions src/modalities/config/look_up_enum.py
@@ -0,0 +1,8 @@
from enum import Enum


class LookupEnum(Enum):
@classmethod
def _missing_(cls, value: str) -> type:
"""constructs Enum by member name, if not constructable by value"""
return cls.__dict__[value]
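
For context, a minimal sketch of what this hook enables; the Color enum below is hypothetical and only for illustration. Subclasses of LookupEnum can be constructed from a member name as well as from a member value:

from enum import Enum


class LookupEnum(Enum):
    @classmethod
    def _missing_(cls, value: str) -> type:
        """Constructs the Enum by member name if it is not constructable by value."""
        return cls.__dict__[value]


class Color(LookupEnum):  # hypothetical example enum
    RED = 1
    GREEN = 2


assert Color(1) is Color.RED        # regular lookup by value
assert Color("RED") is Color.RED    # lookup by member name via _missing_
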
7 changes: 7 additions & 0 deletions src/modalities/config/utils.py
@@ -0,0 +1,7 @@
from typing import Any, Dict

from pydantic import BaseModel

def convert_base_model_config_to_dict(config: BaseModel) -> Dict[Any, Any]:
""""Converts non-recursively a Pydantic BaseModel to a dictionary."""
return {key: getattr(config, key) for key in config.model_dump().keys()}
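
A minimal usage sketch of the new helper; DummyConfig is hypothetical, and the snippet assumes the modalities package is installed. Unlike model_dump(), the helper keeps attribute values as-is, so nested BaseModel instances are not recursively converted:

from pydantic import BaseModel

from modalities.config.utils import convert_base_model_config_to_dict


class DummyConfig(BaseModel):  # hypothetical config for illustration
    n_embd: int
    n_head: int


config = DummyConfig(n_embd=768, n_head=12)
# Attribute values are passed through unchanged, which is what lets
# type_hint.value(**kwargs) later receive them exactly as configured.
print(convert_base_model_config_to_dict(config))  # {'n_embd': 768, 'n_head': 12}
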
170 changes: 159 additions & 11 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -1,19 +1,108 @@
import math
from enum import Enum
from functools import partial
from typing import Annotated, Dict
from typing import Annotated, Dict, List, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, model_validator, validator
from torch.nn import functional as F

from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import NNModel
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT


class PositionTypes(str, Enum):
ABSOLUTE = "ABSOLUTE"
NOPE = "NOPE"


class QueryKeyValueTransform(nn.Module):
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass


class IdentityTransform(QueryKeyValueTransform):
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return q, k, v


class RotaryTransform(QueryKeyValueTransform):
"""Implementation of Rotary Positioanl Embeddings
Source: https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
We added the corresponding code here, becauase there is a conflict with "@torch.jit.script" used in the
XFormers implementation and removed in this implementation.
"""

def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2):
super().__init__()
dim_model = n_embd // n_head
self.seq_length_dim = seq_length_dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None

def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

def _update_cos_sin_tables(self, x):
seq_len = x.shape[self.seq_length_dim]

# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
self._seq_len_cached = seq_len
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

return self._cos_cached, self._sin_cached

def apply_rotary_pos_emb(self, x, cos, sin):
# NOTE: This could probably be moved to Triton

# Handle a possible sequence length mismatch in between q and k
cos = cos[:, :, : x.shape[self.seq_length_dim], :]
sin = sin[:, :, : x.shape[self.seq_length_dim], :]

return (x * cos) + (self.rotate_half(x) * sin)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)

return q, k, v


class QueryKeyValueTransformType(Enum):
IdentityTransform = IdentityTransform
RotaryTransform = RotaryTransform


class AttentionType(str, Enum):
DEFAULT_ATTENTION = "default_attention"
PYTORCH_FLASH_ATTENTION = "pytorch_flash_attention"
@@ -25,18 +114,36 @@ class ActivationType(str, Enum):


class AttentionConfig(BaseModel):
class QueryKeyValueTransformConfig(BaseModel):
class IdentityTransformConfig(BaseModel):
pass

class RotaryTransformConfig(BaseModel):
n_embd: Annotated[int, Field(strict=True, ge=0)]
n_head: Annotated[int, Field(strict=True, ge=0)]
seq_length_dim: Annotated[int, Field(strict=True)]

@validator("type_hint", pre=True, always=True)
def parse_sharding_strategy_by_name(cls, name):
return parse_enum_by_name(name=name, enum_type=QueryKeyValueTransformType)

type_hint: QueryKeyValueTransformType
config: RotaryTransformConfig | IdentityTransformConfig

attention_type: AttentionType
qkv_transforms: List[QueryKeyValueTransformConfig]
scaling_factor: Annotated[int, Field(strict=True, ge=1)]


class WeightInitailizationConfig(BaseModel):
class WeightInitializationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)]


class GPT2LLMConfig(BaseModel):
sample_key: str
prediction_key: str
poe_type: PositionTypes
block_size: Annotated[int, Field(strict=True, ge=1)]
vocab_size: Annotated[
int, Field(strict=True, ge=1)
@@ -51,7 +158,7 @@ class GPT2LLMConfig(BaseModel):
attention: AttentionConfig
activation: ActivationType
epsilon: Annotated[float, Field(strict=True, ge=0.0)]
weight_init: WeightInitailizationConfig
weight_init: WeightInitializationConfig

@model_validator(mode="after")
def validate_sizes(self) -> "GPT2LLMConfig":
@@ -85,14 +192,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

class CausalSelfAttention(nn.Module):
def __init__(
self, n_head: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int
self,
n_head: int,
n_embd: int,
attention: AttentionConfig,
bias: bool,
dropout: float,
block_size: int,
):
super().__init__()
assert n_embd % n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(
in_features=n_embd,
out_features=attention.scaling_factor * n_embd,
# 3, because we have queries, keys, and values
out_features=3 * n_embd,
bias=bias,
)

@@ -111,6 +225,12 @@ def __init__(
self.dropout = dropout
self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION

# TODO: inject QKVTransforms from outside
self.qkv_transforms = nn.ModuleList(
transform_config.type_hint.value(**convert_base_model_config_to_dict(transform_config.config))
for transform_config in attention.qkv_transforms
)

if not self.flash:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
@@ -127,6 +247,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

# TODO: move logic into a function
for qkv_transform in self.qkv_transforms:
q, k, v = qkv_transform(q, k, v)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
@@ -215,6 +339,7 @@ def __init__(
self,
sample_key: str,
prediction_key: str,
poe_type: PositionTypes,
block_size: int,
vocab_size: int,
n_layer: int,
@@ -226,20 +351,37 @@ def __init__(
attention: AttentionConfig,
activation: ActivationType,
epsilon: float,
weight_init: WeightInitailizationConfig,
weight_init: WeightInitializationConfig,
):
super().__init__()
self.sample_key = sample_key
self.prediction_key = prediction_key
self.block_size = block_size
self.poe_type = poe_type

assert vocab_size is not None
assert block_size is not None

# TODO: dependency injection
if poe_type is PositionTypes.ABSOLUTE:
wpe = nn.Embedding(num_embeddings=block_size, embedding_dim=n_embd)
elif poe_type is PositionTypes.NOPE:
# Using a pre-trained (frozen) embedding layer would require defining a separate FSDP unit for it, cf.
# https://github.com/huggingface/accelerate/issues/807
# wpe = nn.Embedding.from_pretrained(torch.zeros(block_size, n_embd))
wpe = nn.Identity()
else:
raise TypeError(f"{poe_type} not supported")

if poe_type is not PositionTypes.NOPE and RotaryTransform in [
config.type_hint.value for config in attention.qkv_transforms
]:
raise ValueError('It is expected to use "RotaryTransform" together with "NOPE".')

self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd),
wpe=nn.Embedding(num_embeddings=block_size, embedding_dim=n_embd),
wpe=wpe,
drop=nn.Dropout(dropout),
h=nn.ModuleList(
[
@@ -274,7 +416,7 @@ def __init__(
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=weight_init.mean, std=weight_init.std / math.sqrt(2 * n_layer))

def _init_weights(self, module: nn.Module, weight_init: WeightInitailizationConfig):
def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
@@ -291,8 +433,14 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso

# forward the GPT model itself
tok_emb = self.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)

if self.poe_type is PositionTypes.ABSOLUTE:
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
tok_emb = tok_emb + pos_emb

# TODO: use drop out also without absolute position embedding?
x = self.transformer.drop(tok_emb)

for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
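
For intuition, a minimal standalone sketch of the rotary transform added above; shapes and values are purely illustrative, and the snippet assumes the modalities package (with its xformers dependency) is importable:

import torch

from modalities.models.gpt2.gpt2_model import RotaryTransform

# Illustrative shapes: (batch, n_head, sequence length, head dim), as produced by CausalSelfAttention.
B, n_head, T, head_dim = 2, 4, 16, 32
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# n_embd // n_head yields the per-head dimension, matching the constructor in the diff.
rope = RotaryTransform(n_embd=n_head * head_dim, n_head=n_head, seq_length_dim=-2)
q_rot, k_rot, v_rot = rope(q, k, v)

assert q_rot.shape == q.shape and k_rot.shape == k.shape
assert torch.equal(v_rot, v)  # v is passed through untouched; only q and k are rotated
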
5 changes: 5 additions & 0 deletions src/modalities/util.py
@@ -12,6 +12,11 @@
from modalities.exceptions import TimeRecorderStateError
from modalities.running_env.fsdp.reducer import Reducer

def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
try:
return enum_type[name]
except KeyError:
raise ValidationError(f"Invalid {enum_type} member name: {name}")

def get_callback_interval_in_batches_per_rank(
callback_interval_in_samples: int, local_train_micro_batch_size: int, world_size: int, gradient_acc_steps: int
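
A brief sketch of how the new parse_enum_by_name helper is meant to be used, e.g. by the type_hint validator in AttentionConfig; it assumes the modalities package is installed:

from modalities.models.gpt2.gpt2_model import QueryKeyValueTransformType
from modalities.util import parse_enum_by_name

# Resolves the string from the YAML config ("RotaryTransform") to the enum member;
# its .value is the RotaryTransform class that CausalSelfAttention later instantiates.
member = parse_enum_by_name(name="RotaryTransform", enum_type=QueryKeyValueTransformType)
print(member is QueryKeyValueTransformType.RotaryTransform)  # True
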