
Commit 852b3c5: add warmup steps

samsja committed Aug 2, 2024 (parent: 7a334c1)
Showing 1 changed file with 18 additions and 7 deletions.
open_diloco/train_fsdp.py (18 additions, 7 deletions)

@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
     skip_load_from_peers: bool = False
     world_rank: int
     galaxy_size: int
+    warmup_outerstep: int = 10
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -190,8 +191,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
 
 
 def _get_cosine_schedule_with_warmup_lr_lambda(
-    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_inner_steps: int,
+    warmup_outerstep: int | None,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
 ):
+    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
+        return 0
+
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -200,11 +211,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     return max(0, factor)
 
 
-def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+def get_cosine_schedule_with_warmup(optimizer, config: Config):
     lambda_lr = partial(
         _get_cosine_schedule_with_warmup_lr_lambda,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_training_steps,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_inner_steps=config.hv.local_steps,
+        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)
@@ -301,9 +314,7 @@ def train(config: Config):
     def scheduler_fn(opt):
         return get_cosine_schedule_with_warmup(
             opt,
-            num_warmup_steps=config.warmup_steps,
-            num_training_steps=config.total_steps,
-            num_inner_steps=config.hv.local_steps,
+            config=config,
         )
 
     if config.hv is not None:
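
For context, here is a minimal sketch (not the repository's code) of what the new warmup_outerstep guard does: it holds the learning-rate factor at zero for the first warmup_outerstep steps after every outer synchronization, i.e. at the start of each block of num_inner_steps inner steps, before the usual linear warmup and cosine decay resume. The cosine factor below is an assumption, following the standard Hugging Face formula that the collapsed lines of the diff presumably contain, and the example values are arbitrary.

import math

def lr_lambda(step, *, num_warmup_steps=10, num_training_steps=100,
              num_inner_steps=20, warmup_outerstep=5, num_cycles=0.5, min_lr_rate=0.0):
    # Hold the LR at zero for the first `warmup_outerstep` steps after every
    # outer synchronization (each block of `num_inner_steps` inner steps).
    if warmup_outerstep is not None and step % num_inner_steps < warmup_outerstep:
        return 0
    # Linear warmup, then cosine decay (assumed HF-style formula for the
    # lines collapsed in the diff above).
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(0, factor * (1 - min_lr_rate) + min_lr_rate)

# Steps 20-24, right after the first outer step, return 0; the cosine
# schedule then resumes at step 25:
print([round(lr_lambda(s), 3) for s in range(18, 28)])

One side effect worth noting: step 0 also falls in the zeroed window (0 % num_inner_steps == 0), so with this guard the inner optimizer only starts moving at step warmup_outerstep.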
