diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index fdb1c5f..bc96950 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -28,7 +29,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -39,7 +39,7 @@
 from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from torch.optim.lr_scheduler import LambdaLR
 
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
     skip_load_from_peers: bool = False
     world_rank: int
     galaxy_size: int
+    warmup_outerstep: int = 10
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -189,6 +190,40 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_inner_steps: int,
+    warmup_outerstep: int | None,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
+        return 0
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_inner_steps=config.hv.local_steps,
+        warmup_outerstep=config.hv.warmup_outerstep,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -280,8 +315,7 @@ def train(config: Config):
     def scheduler_fn(opt):
         return get_cosine_schedule_with_warmup(
             opt,
-            num_warmup_steps=config.warmup_steps,
-            num_training_steps=config.total_steps,
+            config=config,
         )
 
     if config.hv is not None:
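
For reference, the replacement get_cosine_schedule_with_warmup keeps the usual linear warmup and cosine decay but pins the learning-rate multiplier to zero for the first warmup_outerstep inner steps after every DiLoCo outer synchronization, i.e. whenever current_step % num_inner_steps < warmup_outerstep. The sketch below illustrates that behavior in isolation; it is not taken from the patch: the step counts (50 warmup steps, 1000 total steps, 100 inner steps per outer step, outer warmup of 10) and the AdamW optimizer on a dummy parameter are made-up stand-ins for the values normally taken from Config and from the FSDP-wrapped model.

import math
from functools import partial

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(current_step, *, num_warmup_steps, num_training_steps,
              num_inner_steps, warmup_outerstep, num_cycles=0.5, min_lr_rate=0.0):
    # Linear warmup at the very start of training.
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    # Hold the LR at zero for the first warmup_outerstep steps after each outer sync.
    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
        return 0.0
    # Otherwise follow the usual cosine decay.
    progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    return max(0.0, factor * (1 - min_lr_rate) + min_lr_rate)


# Hypothetical stand-ins for config.warmup_steps, config.total_steps,
# config.hv.local_steps and config.hv.warmup_outerstep.
schedule = partial(lr_lambda, num_warmup_steps=50, num_training_steps=1000,
                   num_inner_steps=100, warmup_outerstep=10)

optimizer = AdamW([torch.zeros(1, requires_grad=True)], lr=4e-4)
scheduler = LambdaLR(optimizer, schedule)

lrs = []
for step in range(300):
    lrs.append(scheduler.get_last_lr()[0])  # LR that would be used at this step
    optimizer.step()
    scheduler.step()

# Steps 100-109 and 200-209 fall right after an outer synchronization,
# so the LR is held at 0 there before the cosine decay resumes.
print(lrs[99], lrs[100], lrs[109], lrs[110])

Running this prints a nonzero LR at step 99, zeros at steps 100-109, and a resumed cosine value at step 110, which is the effect controlled by the warmup_outerstep field added to HvConfig above.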