Merge branch 'fused-kernels' into 'main'
perf: Replace transformers model with custom implementation and introduce fused kernels for attention and swiglu

See merge request opengpt-x1/llmgym!2
fromm-m committed Nov 6, 2023
2 parents 325633b + 55a9bb7 commit 1578f4e
Showing 8 changed files with 261 additions and 235 deletions.
14 changes: 13 additions & 1 deletion config_files/config.yaml
@@ -18,4 +18,16 @@ runner:
model:
  target_class: llm_gym.gpt2.gpt2_model.GPT2LLM
  prediction_publication_key: logits
  # local_rank: ${oc.env:LOCAL_RANK}
  config:
    block_size: 1024
    vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: 12
    n_head: 12
    n_embd: 768
    dropout: 0.0
    bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    attention:
      attention_type: "pytorch_flash_attention"
      scaling_factor: 3
    activation: "fused_swiglu"
    epsilon: 1e-5
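
The `attention_type` and `activation` values above are what the commit title calls the fused kernels; the actual model code lives in `llm_gym.gpt2.gpt2_model`, whose diff is not shown on this page. As a hedged, stand-alone sketch only (module names, tensor shapes, and the hidden width are illustrative assumptions, not taken from the repository), "pytorch_flash_attention" plausibly maps onto PyTorch's fused scaled-dot-product attention and "fused_swiglu" onto the SwiGLU module from the newly added xformers dependency:

```python
# Hedged sketch only -- not code from llm_gym. Names and shapes are illustrative.
import torch
import torch.nn.functional as F
from xformers.ops import SwiGLU  # provided by the newly added xformers dependency

n_embd, n_head, block_size = 768, 12, 1024  # mirrors the config values above

# "pytorch_flash_attention": F.scaled_dot_product_attention dispatches to a fused
# FlashAttention-style kernel on supported GPUs when the inputs and is_causal allow it.
def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, n_head, seq_len, head_dim)
    return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# "fused_swiglu": xformers' SwiGLU runs the gated projections as one fused op on CUDA
# (assumed to fall back to an eager implementation elsewhere). The hidden width here
# is a placeholder; the real value is defined in gpt2_model.py.
fused_mlp = SwiGLU(in_features=n_embd, hidden_features=4 * n_embd, bias=True)

x = torch.randn(2, block_size, n_embd)                                 # (batch, seq, n_embd)
qkv = x.view(2, block_size, n_head, n_embd // n_head).transpose(1, 2)  # stand-in q/k/v
attn_out = causal_attention(qkv, qkv, qkv)                             # (batch, n_head, seq, head_dim)
mlp_out = fused_mlp(x)                                                 # (batch, seq, n_embd)
```

The appeal of both kernels is the same: fewer kernel launches and memory round-trips than the unfused attention softmax and the separate SwiGLU projections.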
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -13,10 +13,12 @@ dependencies = [
"SentencePiece",
"accelerate",
"rich",
"xformers",
"hydra-core",
"pydantic",
"click",
"click_pathlib",
"xformers"
]

[project.optional-dependencies]
3 changes: 2 additions & 1 deletion src/llm_gym/__main__.py
@@ -13,6 +13,7 @@
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler


from llm_gym.callbacks.batch_progress_callbacks import DummyProgressCallback, PrintProgressCallback
from llm_gym.callbacks.results_callbacks import DummyResultsCallback, ResultsCallback
from llm_gym.checkpointing.checkpointing import Checkpointing
@@ -24,7 +25,7 @@
from llm_gym.forward_pass import ModelInferenceComponent
from llm_gym.fsdp.fsdp_runner import Runner
from llm_gym.gpt2.collator import GPT2LLMCollator, LMWikiBookCorpusDatasetFactory
from llm_gym.gpt2.gpt2_model import GPT2LLM
from llm_gym.gpt2.gpt2_model import GPT2LLM, GPTConfig
from llm_gym.gym import Gym
from llm_gym.loss_functions import Loss
from llm_gym.trainer import Trainer
14 changes: 8 additions & 6 deletions src/llm_gym/config/config.py
@@ -1,11 +1,12 @@
import traceback
from enum import Enum
from typing import Annotated, List
from typing import Annotated

from hydra._internal.utils import _locate
from pydantic import BaseModel, DirectoryPath, conint
from pydantic.functional_validators import AfterValidator

from llm_gym.gpt2.gpt2_model import GPTConfig


def validate_class_path(path: str):
    try:
@@ -17,7 +18,7 @@ def validate_class_path(path: str):
    return path


TargetPath = Annotated[str, AfterValidator(validate_class_path)]
ClassPath = Annotated[str, AfterValidator(validate_class_path)]


class ProcessGroupBackendEnum(str, Enum):
@@ -34,18 +35,19 @@ class TrainingConfig(BaseModel):


class ModelConfig(BaseModel):
    target_class: TargetPath
    target_class: ClassPath
    prediction_publication_key: str
    config: GPTConfig


class LossConfig(BaseModel):
    target_class: TargetPath
    target_class: ClassPath
    target_subscription_key: str
    prediction_subscription_key: str


class RunnerConfig(BaseModel):
    target_class: TargetPath
    target_class: ClassPath
    process_group_backend: ProcessGroupBackendEnum

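For context on the `TargetPath` → `ClassPath` rename and the new `config: GPTConfig` field: `ClassPath` is a pydantic `Annotated` string whose `AfterValidator` calls `validate_class_path`, which uses hydra's `_locate` to check that a dotted path such as `llm_gym.gpt2.gpt2_model.GPT2LLM` resolves to something importable. A minimal sketch of the same pattern (the validator's error handling is abbreviated in the diff, so the wording below is assumed, and the `ModelConfig` here is trimmed down for illustration):

```python
# Minimal sketch of the ClassPath validation pattern from src/llm_gym/config/config.py.
# Error handling is assumed; the real ModelConfig also carries a nested GPTConfig.
from typing import Annotated

from hydra._internal.utils import _locate
from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator


def validate_class_path(path: str) -> str:
    try:
        _locate(path)  # raises if the dotted path cannot be resolved to an object
    except Exception as err:
        raise ValueError(f"{path} is not a loadable class path") from err
    return path


ClassPath = Annotated[str, AfterValidator(validate_class_path)]


class ModelConfig(BaseModel):
    target_class: ClassPath
    prediction_publication_key: str


# Validation runs when the (e.g. hydra-loaded) config is parsed into the model:
cfg = ModelConfig(
    target_class="llm_gym.gpt2.gpt2_model.GPT2LLM",  # must be importable in this environment
    prediction_publication_key="logits",
)
```

With `config: GPTConfig` on the real `ModelConfig`, the nested `config:` block from `config.yaml` is validated by pydantic when the application config is parsed rather than inside the model code.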
209 changes: 0 additions & 209 deletions src/llm_gym/fsdp.py

This file was deleted.

1 change: 0 additions & 1 deletion src/llm_gym/fsdp/__init__.py
@@ -1 +0,0 @@

11 changes: 7 additions & 4 deletions src/llm_gym/fsdp/fsdp_runner.py
@@ -4,9 +4,10 @@
import torch.distributed as dist
from llm_gym.config.config import ProcessGroupBackendEnum
from llm_gym.env_utils import bfSixteen, has_bfloat_support
from llm_gym.gpt2.gpt2_model import NNModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from llm_gym.gpt2.gpt2_model import NNModel, Block
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools


class Runner(ABC):
@@ -36,10 +37,12 @@ def wrap(self, model: NNModel, local_rank: int) -> FSDP:
        else:
            mp_policy = None # defaults to fp32

        transformer_auto_wrapper_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block,})

        # model is on CPU before input to FSDP
        model = FSDP(
            model,
            auto_wrap_policy=None,
            auto_wrap_policy=transformer_auto_wrapper_policy,
            mixed_precision=mp_policy,
            sharding_strategy=sharding_strategy,
            device_id=torch.cuda.current_device(),
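The substantive change in `wrap` above is the auto-wrap policy: instead of `auto_wrap_policy=None` (the whole model as a single FSDP unit), `transformer_auto_wrap_policy` is partially applied with the GPT-2 `Block` class so that each transformer block becomes its own FSDP unit and its parameters can be sharded, gathered, and freed block by block. A stand-alone sketch of the same pattern, with a toy block standing in for `llm_gym.gpt2.gpt2_model.Block` (assumes `torch.distributed` has already been initialized before the wrap call):

```python
# Stand-alone sketch of the FSDP auto-wrap pattern from fsdp_runner.py.
# ToyBlock is a placeholder for llm_gym.gpt2.gpt2_model.Block.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Placeholder transformer block; the real policy targets the GPT-2 Block class."""

    def __init__(self, n_embd: int = 768) -> None:
        super().__init__()
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        return self.proj(x)


def wrap_with_fsdp(model: nn.Module, local_rank: int) -> FSDP:
    # Every ToyBlock instance becomes its own FSDP unit, so parameters are
    # sharded and all-gathered per block instead of for the whole model at once.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ToyBlock},
    )
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=local_rank,  # requires torch.distributed / CUDA device setup beforehand
    )
```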
[Diff for the eighth changed file did not load.]