soap topk compression #190

Draft · wants to merge 27 commits into base: main
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "third_party/gloo"]
path = third_party/gloo
url = https://github.com/facebookincubator/gloo.git
[submodule "third_party/optimizers"]
path = third_party/optimizers
url = git@github.com:PrimeIntellect-ai/optimizers.git
27 changes: 15 additions & 12 deletions pyproject.toml
@@ -6,31 +6,31 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.5.1",
"numpy",
"setuptools",
"numpy>=0",
"setuptools>=0",
"transformers>=4.44.2",
"datasets>=3.0.0",
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@74c94ee",
"einops",
"einops>=0",
"torchdata>=0.8.0",
"fsspec[gcs]>=2024.3.1",
"ninja",
"zstandard",
"pyarrow",
"toposolve",
"psutil",
"torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
"ninja>=0",
"zstandard>=0",
"pyarrow>=0",
"toposolve>=0",
"psutil>=0",
"torch-shampoo",
]

[project.optional-dependencies]


all = [
"wandb",
"wandb>=0",
"asyncio>=3.4.3",
"aiohttp>=3.10.5",
"requests>=2.32.3",
"lm-eval"
"lm-eval>=0",
]


@@ -45,4 +45,7 @@ allow-direct-references = true # allow direct references to git repos in dependencies
line-length = 120

[tool.uv]
dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0", "faker"]
dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0", "faker>=0"]

[tool.uv.sources]
torch-shampoo = { path = "third_party/optimizers", editable = true }
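
The `[tool.uv.sources]` override points `torch-shampoo` at the new `third_party/optimizers` submodule as an editable install. A quick sanity check (a sketch, assuming the submodule is initialized and `uv sync` has run) that the import resolves to the local checkout rather than the upstream facebookresearch release:

```python
# Sketch: verify distributed_shampoo is served from the editable submodule checkout.
# The "third_party/optimizers" path assumption mirrors [tool.uv.sources] above.
import distributed_shampoo

module_path = distributed_shampoo.__file__.replace("\\", "/")
print(module_path)  # expected to live under .../third_party/optimizers/
assert "third_party/optimizers" in module_path
```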
107 changes: 100 additions & 7 deletions src/zeroband/optimizers.py
@@ -1,20 +1,26 @@
from typing import Literal, TypeAlias
from pydantic_config import BaseConfig
import torch
from distributed_shampoo.shampoo_types import AdamGraftingConfig, EigenvalueCorrectedShampooPreconditionerConfig
from matrix_functions_types import DefaultEighEigenvectorConfig, TopKCompressionEigenvectorConfig

from distributed_shampoo import (
DefaultEigenvalueCorrectedShampooConfig,
DistributedShampoo,
FullyShardShampooConfig,
ShampooPT2CompileConfig,
)


class AdamConfig(BaseConfig):
type: Literal["adam"] = "adam" # the literal is used to distinguish between the different optimizers configuration in the union type
type: Literal["adam"] = (
"adam" # the literal is used to distinguish between the different optimizers configuration in the union type
)
lr: float = 4e-4
weight_decay: float = 0.1
betas1: float = 0.9
betas2: float = 0.95


class SoapConfig(BaseConfig):
type: Literal["soap"] = "soap"
lr: float = 4e-4
@@ -25,21 +31,105 @@ class SoapConfig(BaseConfig):
max_preconditioner_dim: int = 8192
precondition_frequency: int = 100

topk: TopKCompressionEigenvectorConfig | None = None

eigen_stats: bool = False


class ShampooConfig(BaseConfig):
type: Literal["shampoo"] = "shampoo"
lr: float = 4e-4
weight_decay: float = 1e-05
betas1: float = 0.9
betas2: float = 0.95

precondition_frequency: int = 100
max_preconditioner_dim: int = 8192


OptimizersConfig: TypeAlias = AdamConfig | SoapConfig | ShampooConfig

# Constants for large matrix patterns in LLaMA
LLAMA_LARGE_MATRIX_PATTERNS = [
"attention.wq", # Attention matrices
"attention.wk",
"attention.wv",
"attention.wo",
"feed_forward.w1", # FFN matrices
"feed_forward.w2",
"feed_forward.w3",
]


def split_model_parameters(model: torch.nn.Module):
"""
Split model parameters into large matrices and other parameters.
Returns a tuple of (other_params, large_matrix_params)
"""
large_params = []
other_params = []

OptimizersConfig: TypeAlias = AdamConfig | SoapConfig
for name, param in model.named_parameters():
if any(pattern in name for pattern in LLAMA_LARGE_MATRIX_PATTERNS):
large_params.append(param)
else:
# Everything else (including tok_embeddings and output) goes here
other_params.append(param)

return other_params, large_params

def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:

def get_optimizer(model: torch.nn.Module, config: OptimizersConfig) -> torch.optim.Optimizer:
if isinstance(config, AdamConfig):
return torch.optim.AdamW(
params,
model.parameters(),
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.betas1, config.betas2),
)
elif isinstance(config, SoapConfig):
amortized_computation_config = DefaultEighEigenvectorConfig if config.topk is None else config.topk

other_params, large_params = split_model_parameters(model)

param_groups = [
{
"params": large_params,
"preconditioner_config": EigenvalueCorrectedShampooPreconditionerConfig(
amortized_computation_config=amortized_computation_config
),
"eigen_stats": config.eigen_stats,
},
{
"params": other_params,
"preconditioner_config": EigenvalueCorrectedShampooPreconditionerConfig(
amortized_computation_config=DefaultEighEigenvectorConfig
),
"eigen_stats": False,
},
]
# we only apply topk compression to large params

return DistributedShampoo(
param_groups,
lr=config.lr,
betas=(config.betas1, config.betas2),
epsilon=1e-12,
weight_decay=config.weight_decay,
max_preconditioner_dim=config.max_preconditioner_dim,
precondition_frequency=config.precondition_frequency,
use_decoupled_weight_decay=True,
# This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
# less expensive and might thereby allow for a smaller `precondition_frequency`.
preconditioner_config=EigenvalueCorrectedShampooPreconditionerConfig(
amortized_computation_config=DefaultEighEigenvectorConfig
),
distributed_config=FullyShardShampooConfig(),
shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
)
elif isinstance(config, ShampooConfig):
return DistributedShampoo(
params,
model.parameters(),
lr=config.lr,
betas=(config.betas1, config.betas2),
epsilon=1e-12,
@@ -49,7 +139,10 @@ def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
use_decoupled_weight_decay=True,
# This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
# less expensive and might thereby allow for a smaller `precondition_frequency`.
preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
grafting_config=AdamGraftingConfig(
beta2=0.999,
epsilon=1e-08,
),
distributed_config=FullyShardShampooConfig(),
shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
)
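
For reviewers, a minimal sketch of how the new `get_optimizer(model, config)` path could be exercised with top-k compression enabled. The `topk_compression` field name on `TopKCompressionEigenvectorConfig` is inferred from the test flag added below, and the tiny stand-in module only exists to show how `split_model_parameters` routes parameters; neither is part of this diff.

```python
# Sketch only: assumes the third_party/optimizers fork exposes
# TopKCompressionEigenvectorConfig(topk_compression=...) as the test flag suggests.
import torch
from matrix_functions_types import TopKCompressionEigenvectorConfig

from zeroband.optimizers import SoapConfig, get_optimizer, split_model_parameters


class TinyBlock(torch.nn.Module):
    """Stand-in module whose parameter names hit LLAMA_LARGE_MATRIX_PATTERNS."""

    def __init__(self, dim: int = 64):
        super().__init__()
        # "attention.wq.weight" matches the pattern list -> routed to the top-k group
        self.attention = torch.nn.ModuleDict({"wq": torch.nn.Linear(dim, dim, bias=False)})
        # "norm.weight" / "norm.bias" do not match -> default eigh group
        self.norm = torch.nn.LayerNorm(dim)


model = TinyBlock()
other_params, large_params = split_model_parameters(model)
assert len(large_params) == 1 and len(other_params) == 2

config = SoapConfig(
    lr=4e-4,
    precondition_frequency=1,
    topk=TopKCompressionEigenvectorConfig(topk_compression=0.1),  # assumed field name
    eigen_stats=True,
)

# get_optimizer(model, config) would build DistributedShampoo with top-k compression
# on large_params only; it expects torch.distributed / FSDP to be initialized, so it
# is not called in this standalone sketch.
```

The per-group `preconditioner_config` means only the large attention and FFN matrices pay for (and benefit from) compression, matching the "we only apply topk compression to large params" comment in the diff.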
30 changes: 28 additions & 2 deletions src/zeroband/train.py
@@ -12,13 +12,14 @@
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

import torch.distributed as dist
from distributed_shampoo import DistributedShampoo
from zeroband import utils
from zeroband.diloco import Diloco
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss

from zeroband.models.llama.model import create_block_mask_from_seqlens
from zeroband.config import Config #, MemoryProfilerConfig
from zeroband.config import Config # , MemoryProfilerConfig
from zeroband.optimizers import get_optimizer

from zeroband.utils import (
@@ -39,6 +40,7 @@
from zeroband.checkpoint import CkptManager, TrainingProgress
from zeroband.lr_scheduler import get_scheduler


def log_hash_training_state(
config: Config,
model: torch.nn.Module,
@@ -164,7 +166,7 @@ def train(config: Config):
logger.debug("model fsdped")

# Setup optimizers
inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)
inner_optimizer = get_optimizer(model, config.optim.optim)

diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None

@@ -361,6 +363,30 @@ def train(config: Config):
"time": time.time(),
}

if (
isinstance(inner_optimizer, DistributedShampoo)
and training_progress.step % config.optim.optim.precondition_frequency == 0
and training_progress.step > 0
and world_info.rank == 0
):
logger.info(f"step {training_progress.step} preconditioning")
eigen_stats = inner_optimizer.eigenvector_stats(key_to_param=model.named_parameters())
# 1/0
og_total_rank = 0
effective_total_rank = 0

for param_name, param_stats in eigen_stats.items():
if param_stats is not None:
for key, val in param_stats.items():
log_stats = val.log_stats()
for sub_key, sub_val in log_stats.items():
metrics[f"eigenvalue_stats/{param_name}/{key}/{sub_key}"] = sub_val

og_total_rank += val.og_rank
effective_total_rank += val.effective_rank

metrics["total_compression"] = 1 - effective_total_rank / og_total_rank if og_total_rank > 0 else 0

if config.optim.z_loss:
metrics["z_loss"] = z_loss_batch.item()

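
To make the new `total_compression` metric concrete, here is a small worked example with made-up ranks (not taken from a real run); the actual `og_rank` / `effective_rank` values come from the fork's `eigenvector_stats` call above.

```python
# Made-up per-matrix ranks, for illustration only.
og_ranks = [4096, 4096]        # full eigenbasis sizes (og_rank)
effective_ranks = [410, 512]   # columns kept after top-k compression (effective_rank)

og_total_rank = sum(og_ranks)
effective_total_rank = sum(effective_ranks)
total_compression = 1 - effective_total_rank / og_total_rank if og_total_rank > 0 else 0
print(f"{total_compression:.3f}")  # ~0.887 -> roughly 89% of eigenvector columns dropped
```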
10 changes: 8 additions & 2 deletions tests/test_torchrun/test_train.py
@@ -116,13 +116,19 @@ def test_packing(packing: bool):
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])





@pytest.mark.parametrize("diloco", [False, True])
def test_soap(diloco: bool):
@pytest.mark.parametrize("topk_compression", [None, 5, 0.1])
def test_soap(diloco: bool, topk_compression: int | None):
num_gpus = [1, 2] if diloco else [2, 1]

_test_multi_gpu(
num_gpus,
"debug/diloco.toml" if diloco else "debug/normal.toml",
extra_args=["--optim.optim.precondition_frequency", "1"],
extra_args=["--optim.optim.precondition_frequency", "1"]
+ (["--optim.optim.topk.topk_compression", str(topk_compression)] if topk_compression is not None else []),
diloco=diloco,
)

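
The parametrization above passes both an int (`5`) and a float (`0.1`) for `topk_compression`, which suggests the fork accepts either an absolute eigenvector count or a kept fraction. That interpretation is an assumption; the hypothetical helper below only illustrates those assumed semantics and is not part of the fork.

```python
def resolve_topk(topk_compression: int | float, full_rank: int) -> int:
    """Hypothetical: int -> keep that many eigenvectors, float in (0, 1) -> keep that fraction."""
    if isinstance(topk_compression, float) and 0 < topk_compression < 1:
        return max(1, round(topk_compression * full_rank))
    return min(int(topk_compression), full_rank)


print(resolve_topk(5, 4096))    # 5 eigenvectors kept
print(resolve_topk(0.1, 4096))  # 410 eigenvectors kept
```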
1 change: 1 addition & 0 deletions third_party/optimizers
Submodule optimizers added at 6b35d4
38 changes: 25 additions & 13 deletions uv.lock

Some generated files are not rendered by default.