From f6748dd8cb7579cb5839f28a3b53558459c8fac6 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Wed, 21 Aug 2024 11:32:35 +0200 Subject: [PATCH 01/15] fix torch compile log act (#23) * fix renaming logic for key * fix stuff * fix exploding norm * remove print --- open_diloco/train_fsdp.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index ab4efe2..3d17eee 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -7,6 +7,7 @@ """ from functools import partial +import math import os import time from contextlib import nullcontext @@ -26,7 +27,6 @@ DataCollatorForLanguageModeling, LlamaConfig, LlamaForCausalLM, - get_cosine_schedule_with_warmup, ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, @@ -46,6 +46,7 @@ ) from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer from open_diloco.utils import WandbLogger, DummyLogger +from torch.optim.lr_scheduler import LambdaLR from hivemind.dht.dht import DHT from hivemind.utils.networking import log_visible_maddrs @@ -173,6 +174,27 @@ def get_model(config: Config) -> LlamaForCausalLM: return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0 +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps): + lambda_lr = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=0.5, + ) + return LambdaLR(optimizer, lambda_lr, -1) + + def train(config: Config): sharding_strategy = get_sharding_strategy(config.sharding_strategy) local_rank = int(os.environ["LOCAL_RANK"]) @@ -254,6 +276,7 @@ def scheduler_fn(opt): opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_steps, + num_inner_steps=config.hv.local_steps, ) if config.hv is not None: From 9d03959d7d79849014408aecbbf1e5670ab52b8d Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 27 Jul 2024 22:03:37 +0000 Subject: [PATCH 02/15] add warmup steps --- open_diloco/train_fsdp.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 3d17eee..0262edb 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -91,6 +91,7 @@ class HvConfig(BaseConfig): world_rank: int galaxy_size: int fail_rank_drop: bool = False # fail if we lose a diloco worker + warmup_outerstep: int = 10 @model_validator(mode="before") def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: @@ -175,8 +176,18 @@ def get_model(config: Config) -> LlamaForCausalLM: def _get_cosine_schedule_with_warmup_lr_lambda( - current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0 + current_step: int, + *, + num_warmup_steps: int, + 
num_training_steps: int, + num_inner_steps: int, + warmup_outerstep: int | None, + num_cycles: float, + min_lr_rate: float = 0.0, ): + if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep: + return 0 + if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) @@ -185,11 +196,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda( return max(0, factor) -def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps): +def get_cosine_schedule_with_warmup(optimizer, config: Config): lambda_lr = partial( _get_cosine_schedule_with_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + num_warmup_steps=config.warmup_steps, + num_training_steps=config.total_steps, + num_inner_steps=config.hv.local_steps, + warmup_outerstep=config.hv.warmup_outerstep, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) @@ -274,9 +287,7 @@ def train(config: Config): def scheduler_fn(opt): return get_cosine_schedule_with_warmup( opt, - num_warmup_steps=config.warmup_steps, - num_training_steps=config.total_steps, - num_inner_steps=config.hv.local_steps, + config=config, ) if config.hv is not None: From ac542182f120e58e7d227ab979d2d38dc11b842b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 2 Aug 2024 14:29:34 +0000 Subject: [PATCH 03/15] do not update lr scheduler during warmup --- open_diloco/train_fsdp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 0262edb..70fff32 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -185,7 +185,11 @@ def _get_cosine_schedule_with_warmup_lr_lambda( num_cycles: float, min_lr_rate: float = 0.0, ): - if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep: + if ( + warmup_outerstep is not None + and current_step > num_warmup_steps + and current_step % num_inner_steps < warmup_outerstep + ): return 0 if current_step < num_warmup_steps: From ce6f82bbdd5d4ad49ccafbdb9cf431c474d018a7 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 2 Aug 2024 14:38:58 +0000 Subject: [PATCH 04/15] do not update lr scheduler during warmup --- open_diloco/train_fsdp.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 70fff32..5e23bd7 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -185,15 +185,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda( num_cycles: float, min_lr_rate: float = 0.0, ): - if ( - warmup_outerstep is not None - and current_step > num_warmup_steps - and current_step % num_inner_steps < warmup_outerstep - ): - return 0 - if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) + + if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep: + return 0 + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) factor = factor * (1 - min_lr_rate) + min_lr_rate From e91cbda7aabdba13ddb4d4ef91c92f15c44e6b38 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 2 Aug 2024 16:31:00 +0000 Subject: [PATCH 05/15] add outer lr scheduler --- open_diloco/hivemind_diloco.py | 8 ++++++- 
open_diloco/train_fsdp.py | 39 ++++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/open_diloco/hivemind_diloco.py b/open_diloco/hivemind_diloco.py index 308608b..0031b8d 100644 --- a/open_diloco/hivemind_diloco.py +++ b/open_diloco/hivemind_diloco.py @@ -334,6 +334,7 @@ def __init__( inner_optimizer: OptimizerFactory, params: Optional[Union[Parameters, ParamGroups]] = None, scheduler: Optional[SchedulerFactory] = None, + outer_scheduler: Optional[SchedulerFactory] = None, averager_opts: Optional[dict] = None, grad_compression: CompressionBase = NoCompression(), tracker_opts: Optional[dict] = None, @@ -365,7 +366,7 @@ def __init__( # since we have two optimizers, we need to persist the params to a list self.num_inner_steps = num_inner_steps - for opt_or_scheduler in [outer_optimizer, scheduler]: + for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]: if not (callable(opt_or_scheduler) or opt_or_scheduler is None): raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable") @@ -405,6 +406,8 @@ def __init__( ) self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression) + self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) + def _check_kwargs(self, kwargs) -> None: """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs. This function raise an error if some kwargs are not supported""" @@ -555,6 +558,9 @@ def step( if self.tracker.ready_to_update_epoch: self._update_global_epoch() + if self.outer_scheduler is not None: + self.outer_scheduler.step() + return loss def _compute_schema_hash(self) -> int: diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 5e23bd7..4e879be 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -92,6 +92,8 @@ class HvConfig(BaseConfig): galaxy_size: int fail_rank_drop: bool = False # fail if we lose a diloco worker warmup_outerstep: int = 10 + outer_lr_min: float = 0.3 + outer_scheduler: bool = False @model_validator(mode="before") def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: @@ -180,17 +182,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda( *, num_warmup_steps: int, num_training_steps: int, - num_inner_steps: int, - warmup_outerstep: int | None, num_cycles: float, min_lr_rate: float = 0.0, ): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) - if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep: - return 0 - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) factor = factor * (1 - min_lr_rate) + min_lr_rate @@ -202,13 +199,36 @@ def get_cosine_schedule_with_warmup(optimizer, config: Config): _get_cosine_schedule_with_warmup_lr_lambda, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_steps, - num_inner_steps=config.hv.local_steps, - warmup_outerstep=config.hv.warmup_outerstep, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) +def _get_lr_outer( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + min_lr_rate: float = 0.0, +): + if current_step < num_warmup_steps: + return 1 + + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + 
min_lr_rate + return max(0, factor) + + +def get_lr_outer(optimizer, config: Config): + lambda_lr = partial( + _get_lr_outer, + num_warmup_steps=config.warmup_steps, + num_training_steps=config.total_steps, + ) + return LambdaLR(optimizer, lambda_lr, -1) + + def train(config: Config): sharding_strategy = get_sharding_strategy(config.sharding_strategy) local_rank = int(os.environ["LOCAL_RANK"]) @@ -291,6 +311,9 @@ def scheduler_fn(opt): config=config, ) + def outer_scheduler_fn(opt): + return get_lr_outer(opt, config=config) + if config.hv is not None: if config.ckpt.resume: # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer @@ -316,6 +339,7 @@ def scheduler_fn(opt): outer_optimizer=outer_optimizer, inner_optimizer=inner_optimizer, scheduler=None, + outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None, params=model.parameters(), delay_optimizer_step=False, delay_grad_averaging=False, @@ -435,6 +459,7 @@ def scheduler_fn(opt): scaler.update() scheduler.step() + optimizer.zero_grad() if config.hv is not None: From d0dc44c76314459668650b8d3d97c1697be1557d Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 5 Aug 2024 15:17:36 +0000 Subject: [PATCH 06/15] add div by 4 --- open_diloco/train_fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 4e879be..c7566fd 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -224,7 +224,8 @@ def get_lr_outer(optimizer, config: Config): lambda_lr = partial( _get_lr_outer, num_warmup_steps=config.warmup_steps, - num_training_steps=config.total_steps, + # num_training_steps=config.total_steps, + num_training_steps=config.total_steps / 4, ) return LambdaLR(optimizer, lambda_lr, -1) From 532af191af85e0c70dcc85943a2686c704cfe36f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 5 Aug 2024 15:34:59 +0000 Subject: [PATCH 07/15] support none outer lr --- open_diloco/hivemind_diloco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/hivemind_diloco.py b/open_diloco/hivemind_diloco.py index 0031b8d..74d94d5 100644 --- a/open_diloco/hivemind_diloco.py +++ b/open_diloco/hivemind_diloco.py @@ -406,7 +406,7 @@ def __init__( ) self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression) - self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) + self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) if outer_scheduler else None def _check_kwargs(self, kwargs) -> None: """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs. 
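Patches 02-04 above gate the inner learning rate to zero for the first `warmup_outerstep` inner steps of every outer round (after the global warmup), and patches 05-07 then drop that gating in favour of a separate cosine schedule on the outer optimizer. The following is a minimal, standalone sketch of the patch-04 inner lambda so the schedule shape can be inspected outside the training loop; the AdamW optimizer and all step counts (`num_warmup_steps=10`, `num_training_steps=200`, `num_inner_steps=50`, `warmup_outerstep=10`) are illustrative assumptions, not values taken from the repo's configs.

```python
# Standalone sketch of the inner-step schedule from patches 02-04 (later replaced
# by the outer scheduler in patches 05-07). All hyperparameters are illustrative
# assumptions, not values taken from the repo's configs.
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def inner_lr_lambda(
    step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_inner_steps: int,
    warmup_outerstep: int | None,
    num_cycles: float = 0.5,
    min_lr_rate: float = 0.0,
):
    # Linear warmup for the first num_warmup_steps inner steps.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # Patch-04 ordering: after the global warmup, hold the inner LR at zero for
    # the first `warmup_outerstep` inner steps of every outer round.
    if warmup_outerstep is not None and step % num_inner_steps < warmup_outerstep:
        return 0.0
    # Otherwise follow the usual cosine decay.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    return max(0.0, factor * (1 - min_lr_rate) + min_lr_rate)


if __name__ == "__main__":
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=4e-4)
    sched = LambdaLR(
        opt,
        partial(
            inner_lr_lambda,
            num_warmup_steps=10,
            num_training_steps=200,
            num_inner_steps=50,
            warmup_outerstep=10,
        ),
    )
    lrs = []
    for _ in range(200):
        opt.step()
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    # The slice around the first outer boundary shows cosine -> ten zeros -> cosine.
    print([round(lr, 6) for lr in lrs[47:62]])
```

With these numbers the printed slice shows the cosine value just before step 50, ten consecutive zeros while `step % num_inner_steps < warmup_outerstep`, and the cosine resuming afterwards; patch 05 replaces this per-round gating with the dedicated outer scheduler.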
From da41f5b9cf6ee353e5c438b68126aa0917a0afc7 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 6 Aug 2024 12:26:24 +0000 Subject: [PATCH 08/15] fix outer lr schedulng --- open_diloco/train_fsdp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index c7566fd..668a554 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -209,13 +209,14 @@ def _get_lr_outer( *, num_warmup_steps: int, num_training_steps: int, + num_cycles: float, min_lr_rate: float = 0.0, ): if current_step < num_warmup_steps: return 1 progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress)) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) factor = factor * (1 - min_lr_rate) + min_lr_rate return max(0, factor) @@ -226,6 +227,7 @@ def get_lr_outer(optimizer, config: Config): num_warmup_steps=config.warmup_steps, # num_training_steps=config.total_steps, num_training_steps=config.total_steps / 4, + num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) From 16357ca40540d58dac3459c05bfa2bb97e22eba3 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 6 Aug 2024 16:53:55 +0000 Subject: [PATCH 09/15] fix outer lr schedulng --- open_diloco/train_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 668a554..6672f37 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -226,7 +226,7 @@ def get_lr_outer(optimizer, config: Config): _get_lr_outer, num_warmup_steps=config.warmup_steps, # num_training_steps=config.total_steps, - num_training_steps=config.total_steps / 4, + num_training_steps=config.total_steps / 10, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) From 4ea27359befdbc024a41a69942de191c6118f480 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 7 Aug 2024 15:25:44 +0000 Subject: [PATCH 10/15] update divide --- open_diloco/train_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 6672f37..668a554 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -226,7 +226,7 @@ def get_lr_outer(optimizer, config: Config): _get_lr_outer, num_warmup_steps=config.warmup_steps, # num_training_steps=config.total_steps, - num_training_steps=config.total_steps / 10, + num_training_steps=config.total_steps / 4, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) From 1e43b47efbd02cd2bebbdb11b62d8e03438ff549 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 8 Aug 2024 08:49:44 +0000 Subject: [PATCH 11/15] add ckpt for outer schedler --- open_diloco/ckpt_utils.py | 9 +++++++++ open_diloco/train_fsdp.py | 1 + 2 files changed, 10 insertions(+) diff --git a/open_diloco/ckpt_utils.py b/open_diloco/ckpt_utils.py index 261b7c4..97688c6 100644 --- a/open_diloco/ckpt_utils.py +++ b/open_diloco/ckpt_utils.py @@ -40,6 +40,7 @@ def save_checkpoint( model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LambdaLR, + outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, outer_optimizer: torch.optim.Optimizer | None = None, scaler: torch.cuda.amp.GradScaler | None = None, loss: float | None = None, @@ -81,6 +82,8 @@ def save_checkpoint( # 2. 
Save global states global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0} + if outer_scheduler is not None: + global_state_dict["outer_scheduler"] = outer_scheduler.state_dict() if outer_optimizer is not None: global_state_dict["outer_optimizer"] = outer_optimizer.state_dict() if scaler is not None: @@ -95,6 +98,7 @@ def load_checkpoint( model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, + outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None, outer_optimizer: torch.optim.Optimizer | None = None, scaler: torch.cuda.amp.GradScaler | None = None, data_loader: StatefulDataLoader | None = None, @@ -139,8 +143,13 @@ def load_checkpoint( if scheduler is not None: scheduler.load_state_dict(global_state_dict["scheduler"]) optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0] + if outer_optimizer is not None: outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"]) + if outer_scheduler is not None: + outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"]) + outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0] + if scaler is not None: scaler.load_state_dict(global_state_dict["scaler"]) return global_state_dict["loss"] diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 668a554..a4faa6f 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -539,6 +539,7 @@ def outer_scheduler_fn(opt): model=model, optimizer=optimizer.inner_optimizer, scheduler=scheduler, + outer_scheduler=optimizer.outer_scheduler, outer_optimizer=optimizer.state_averager.optimizer, loss=loss_batch.item(), scaler=scaler, From 7328d7d245c7441edf882710d25d95004ff2d30f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 8 Aug 2024 09:41:10 +0000 Subject: [PATCH 12/15] add ckpt for outer schedler --- open_diloco/train_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index a4faa6f..e5fd3c8 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -373,6 +373,7 @@ def outer_scheduler_fn(opt): model=model, optimizer=optimizer.inner_optimizer, scheduler=scheduler, + outer_scheduler=optimizer.outer_scheduler, outer_optimizer=optimizer.state_averager.optimizer, scaler=scaler, data_loader=train_dataloader, From 392a0159aa6e7f966cef893911496c89dc42a686 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 11 Aug 2024 14:32:58 +0000 Subject: [PATCH 13/15] change div --- open_diloco/train_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index e5fd3c8..09b43cd 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -226,7 +226,7 @@ def get_lr_outer(optimizer, config: Config): _get_lr_outer, num_warmup_steps=config.warmup_steps, # num_training_steps=config.total_steps, - num_training_steps=config.total_steps / 4, + num_training_steps=config.total_steps / 5, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1) From b2b3119a84e8f8d99e71fd196cd3cba20ac14d8f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 12 Aug 2024 19:14:09 +0000 Subject: [PATCH 14/15] remove min olr --- open_diloco/train_fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 09b43cd..9b5d76a 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -92,7 +92,6 @@ class HvConfig(BaseConfig): 
galaxy_size: int fail_rank_drop: bool = False # fail if we lose a diloco worker warmup_outerstep: int = 10 - outer_lr_min: float = 0.3 outer_scheduler: bool = False @model_validator(mode="before") From bf21c1302b10a9ce70a42dfd13b61aaf58297af7 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 13 Aug 2024 16:46:43 +0000 Subject: [PATCH 15/15] remove divi --- open_diloco/train_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 9b5d76a..4383e7c 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -225,7 +225,7 @@ def get_lr_outer(optimizer, config: Config): _get_lr_outer, num_warmup_steps=config.warmup_steps, # num_training_steps=config.total_steps, - num_training_steps=config.total_steps / 5, + num_training_steps=config.total_steps, num_cycles=0.5, ) return LambdaLR(optimizer, lambda_lr, -1)
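Patches 08-15 settle on a plain cosine decay for the outer learning rate (flat at 1.0 during warmup, decaying over `total_steps` once the divide-by-4/5 experiments are reverted) and, in patches 11-12, persist the outer scheduler in the checkpoint so the outer optimizer's learning rate is restored on resume. Below is a minimal, repo-independent sketch of that save/restore round trip; the SGD outer optimizer with `lr=0.7`, the 40 simulated outer rounds, and the step counts are stand-in assumptions for illustration only.

```python
# Sketch of the outer-scheduler checkpoint round trip added in patches 11-12.
# SGD with Nesterov momentum stands in for the outer optimizer; the lr, step
# counts and number of simulated outer rounds are illustrative assumptions.
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def outer_lr_lambda(
    step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_rate: float = 0.0,
):
    # Shape of _get_lr_outer at the end of the series: flat at 1.0 during warmup,
    # then cosine decay over the full horizon (patch 15 reverts the divides).
    if step < num_warmup_steps:
        return 1.0
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    return max(0.0, factor * (1 - min_lr_rate) + min_lr_rate)


def make_outer(total_steps: int = 100, warmup_steps: int = 10):
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.7, momentum=0.9, nesterov=True)
    sched = LambdaLR(
        opt,
        partial(outer_lr_lambda, num_warmup_steps=warmup_steps, num_training_steps=total_steps),
    )
    return opt, sched


opt, sched = make_outer()
for _ in range(40):  # pretend 40 outer rounds have run
    opt.step()
    sched.step()

# Save: the same keys save_checkpoint adds to the global state dict.
ckpt = {"outer_optimizer": opt.state_dict(), "outer_scheduler": sched.state_dict()}

# Resume: rebuild optimizer and scheduler, load the states, then push the restored
# learning rate back into param_groups[0], mirroring load_checkpoint.
new_opt, new_sched = make_outer()
new_opt.load_state_dict(ckpt["outer_optimizer"])
new_sched.load_state_dict(ckpt["outer_scheduler"])
new_opt.param_groups[0]["lr"] = new_sched.get_last_lr()[0]

assert new_opt.param_groups[0]["lr"] == sched.get_last_lr()[0]
```

A `LambdaLR` state dict does not serialize the lambda itself, so the schedule function has to be rebuilt before `load_state_dict` is called on resume, which matches `load_checkpoint` receiving an already-constructed `outer_scheduler`.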