Feat outer lr scheduler #20

Draft · wants to merge 15 commits into base: main
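
This draft wires an optional learning-rate scheduler for DiLoCo's outer optimizer through the stack: `ckpt_utils.py` saves and restores the outer scheduler's state alongside the outer optimizer, `hivemind_diloco.py` accepts an `outer_scheduler` factory on `DiLoCoOptimizer` and steps it in `step()`, and `train_fsdp.py` defines local cosine-with-warmup schedules (the outer variant holds the lr at its base value during warmup), gated behind a new `HvConfig.outer_scheduler` flag.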
open_diloco/ckpt_utils.py — 9 changes: 9 additions & 0 deletions

@@ -40,6 +40,7 @@ def save_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     loss: float | None = None,
@@ -81,6 +82,8 @@ def save_checkpoint(
 
     # 2. Save global states
     global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0}
+    if outer_scheduler is not None:
+        global_state_dict["outer_scheduler"] = outer_scheduler.state_dict()
     if outer_optimizer is not None:
         global_state_dict["outer_optimizer"] = outer_optimizer.state_dict()
     if scaler is not None:
@@ -95,6 +98,7 @@ def load_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     data_loader: StatefulDataLoader | None = None,
@@ -139,8 +143,13 @@ def load_checkpoint(
     if scheduler is not None:
         scheduler.load_state_dict(global_state_dict["scheduler"])
         optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0]
+
     if outer_optimizer is not None:
         outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
+    if outer_scheduler is not None:
+        outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
+        outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0]
+
     if scaler is not None:
         scaler.load_state_dict(global_state_dict["scaler"])
     return global_state_dict["loss"]
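
The new branches mirror the existing inner-scheduler handling: the outer scheduler's `state_dict` rides along in the global state dict, and on load the outer optimizer's lr is re-synced from `get_last_lr()`. Note the restore branch assumes `outer_optimizer` is passed whenever `outer_scheduler` is. A minimal self-contained sketch of that round trip with plain torch objects (toy model and arbitrary values, not a call into `ckpt_utils`):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7)
outer_sched = LambdaLR(outer_opt, lambda step: 0.99**step)
for _ in range(5):
    outer_sched.step()

# save side: scheduler state rides in the global state dict
global_state_dict = {"outer_scheduler": outer_sched.state_dict()}

# load side: rebuild the same objects, restore state, re-sync the lr
restored_opt = torch.optim.SGD(model.parameters(), lr=0.7)
restored_sched = LambdaLR(restored_opt, lambda step: 0.99**step)
restored_sched.load_state_dict(global_state_dict["outer_scheduler"])
restored_opt.param_groups[0]["lr"] = restored_sched.get_last_lr()[0]
assert restored_sched.last_epoch == outer_sched.last_epoch
```
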
open_diloco/hivemind_diloco.py — 8 changes: 7 additions & 1 deletion

@@ -334,6 +334,7 @@ def __init__(
         inner_optimizer: OptimizerFactory,
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[SchedulerFactory] = None,
+        outer_scheduler: Optional[SchedulerFactory] = None,
         averager_opts: Optional[dict] = None,
         grad_compression: CompressionBase = NoCompression(),
         tracker_opts: Optional[dict] = None,
@@ -365,7 +366,7 @@ def __init__(
         # since we have two optimizers, we need to persist the params to a list
         self.num_inner_steps = num_inner_steps
 
-        for opt_or_scheduler in [outer_optimizer, scheduler]:
+        for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]:
             if not (callable(opt_or_scheduler) or opt_or_scheduler is None):
                 raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable")
 
@@ -405,6 +406,8 @@ def __init__(
         )
         self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression)
 
+        self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) if outer_scheduler else None
+
     def _check_kwargs(self, kwargs) -> None:
         """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs.
         This function raise an error if some kwargs are not supported"""
@@ -555,6 +558,9 @@ def step(
         if self.tracker.ready_to_update_epoch:
             self._update_global_epoch()
 
+        if self.outer_scheduler is not None:
+            self.outer_scheduler.step()
+
         return loss
 
     def _compute_schema_hash(self) -> int:
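
The optimizer keeps hivemind's factory convention: `outer_scheduler` is passed as a callable mapping an optimizer to a scheduler, validated alongside `outer_optimizer` and `scheduler`, and only instantiated once `state_averager.optimizer` exists. A sketch of that pattern outside the class (the names here are illustrative stand-ins, not the real `DiLoCoOptimizer` API):

```python
from typing import Callable, Optional

import torch
from torch.optim.lr_scheduler import LambdaLR

# a factory takes the (not-yet-built) optimizer and returns a scheduler
SchedulerFactory = Callable[[torch.optim.Optimizer], LambdaLR]

def make_outer_scheduler(
    outer_optimizer: torch.optim.Optimizer,
    outer_scheduler: Optional[SchedulerFactory] = None,
) -> Optional[LambdaLR]:
    # same validation as the constructor loop above
    if not (callable(outer_scheduler) or outer_scheduler is None):
        raise TypeError("outer_scheduler must be a callable factory or None")
    # same instantiation as `outer_scheduler(self.state_averager.optimizer) if outer_scheduler else None`
    return outer_scheduler(outer_optimizer) if outer_scheduler else None

params = [torch.nn.Parameter(torch.zeros(4))]
outer_opt = torch.optim.SGD(params, lr=0.7)
sched = make_outer_scheduler(outer_opt, lambda opt: LambdaLR(opt, lambda s: 1.0))
assert sched is not None
```
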
open_diloco/train_fsdp.py — 70 changes: 67 additions & 3 deletions

@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -26,7 +27,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -46,6 +46,7 @@
 )
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
 from open_diloco.utils import WandbLogger, DummyLogger
+from torch.optim.lr_scheduler import LambdaLR
 
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -90,6 +91,8 @@ class HvConfig(BaseConfig):
     world_rank: int
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
+    warmup_outerstep: int = 10
+    outer_scheduler: bool = False
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -173,6 +176,61 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
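
The two lambdas differ only during warmup: the inner schedule ramps linearly from 0 to 1, while `_get_lr_outer` returns a constant 1 so the outer optimizer runs at its full base lr from the start; after warmup both follow the same cosine decay. Note that `get_lr_outer` is driven by `config.warmup_steps` and `config.total_steps` (inner-step counts), and the new `warmup_outerstep` field is not referenced anywhere else in this diff. A quick numeric check of the factors (assuming `warmup_steps=10`, `total_steps=100`, and the defaults `num_cycles=0.5`, `min_lr_rate=0.0`):

```python
import math

def factor(step: int, warmup: int, total: int, outer: bool) -> float:
    """Replicates the two lambdas above with num_cycles=0.5 and min_lr_rate=0."""
    if step < warmup:
        return 1.0 if outer else step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 0.5 * 2.0 * progress)))

print(factor(5, 10, 100, outer=False))   # 0.5 -> inner lr still ramping up
print(factor(5, 10, 100, outer=True))    # 1.0 -> outer lr already at full value
print(factor(55, 10, 100, outer=False))  # 0.5 -> halfway through the cosine decay
print(factor(55, 10, 100, outer=True))   # 0.5 -> identical to inner after warmup
```
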
@@ -252,10 +310,12 @@ def train(config: Config):
     def scheduler_fn(opt):
         return get_cosine_schedule_with_warmup(
             opt,
-            num_warmup_steps=config.warmup_steps,
-            num_training_steps=config.total_steps,
+            config=config,
         )
 
+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -281,6 +341,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
@@ -311,6 +372,7 @@ def scheduler_fn(opt):
                 model=model,
                 optimizer=optimizer.inner_optimizer,
                 scheduler=scheduler,
+                outer_scheduler=optimizer.outer_scheduler,
                 outer_optimizer=optimizer.state_averager.optimizer,
                 scaler=scaler,
                 data_loader=train_dataloader,
@@ -400,6 +462,7 @@ def scheduler_fn(opt):
             scaler.update()
 
             scheduler.step()
+
             optimizer.zero_grad()
 
         if config.hv is not None:
@@ -476,6 +539,7 @@ def scheduler_fn(opt):
                     model=model,
                     optimizer=optimizer.inner_optimizer,
                     scheduler=scheduler,
+                    outer_scheduler=optimizer.outer_scheduler,
                     outer_optimizer=optimizer.state_averager.optimizer,
                     loss=loss_batch.item(),
                     scaler=scaler,
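
End to end, the feature is opt-in: `outer_scheduler_fn` is handed to `DiLoCoOptimizer` only when `config.hv.outer_scheduler` is true, and the resulting `optimizer.outer_scheduler` (possibly `None`) flows into both checkpoint paths. A stand-in sketch of that gate using a plain dataclass instead of the pydantic `HvConfig` (hypothetical names, for illustration only):

```python
from dataclasses import dataclass

import torch
from torch.optim.lr_scheduler import LambdaLR

@dataclass
class FakeHvConfig:  # stand-in for HvConfig, not the real pydantic model
    outer_scheduler: bool = False
    warmup_outerstep: int = 10

def outer_scheduler_fn(opt: torch.optim.Optimizer) -> LambdaLR:
    return LambdaLR(opt, lambda step: 1.0)

hv = FakeHvConfig(outer_scheduler=True)
factory = outer_scheduler_fn if hv.outer_scheduler else None  # same gate as the diff

params = [torch.nn.Parameter(torch.zeros(2))]
outer_opt = torch.optim.SGD(params, lr=0.7)
outer_sched = factory(outer_opt) if factory else None  # mirrors DiLoCoOptimizer.__init__
```
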