Commit
Merge pull request #71 from argonne-lcf/aurora-post-at
saforem2 authored Dec 23, 2024
2 parents b8007f4 + 1504bd0 commit 587aafd
Showing 5 changed files with 184 additions and 130 deletions.
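Most of the diff reorganizes get_megatron_optimizer in megatron/optimizer/__init__.py: every optimizer now lives in its own banner-commented branch, optimizer is initialized to None before the dispatch, and a final assert optimizer is not None catches any branch that forgets to assign before the DeepSpeed early return. A condensed, self-contained sketch of that dispatch shape is below; the option names, arguments, and helper function are simplified illustrations, not the actual Megatron signatures.

import torch

def build_optimizer(name: str, params, lr: float = 1e-3):
    """Illustrative dispatch in the style of the refactored get_megatron_optimizer."""
    optimizer = None
    # ---- Adam --------------------------------------
    if name == "adam":
        optimizer = torch.optim.Adam(params, lr=lr)
    # ---- AdamW -------------------------------------
    elif name == "adamw":
        optimizer = torch.optim.AdamW(params, lr=lr)
    # ---- SGD ---------------------------------------
    elif name == "sgd":
        optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
    else:
        raise TypeError(f"{name} optimizer is not supported.")
    # Guards against a branch that matched but never assigned an optimizer.
    assert optimizer is not None
    return optimizer

opt = build_optimizer("adamw", torch.nn.Linear(4, 2).parameters())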
4 changes: 2 additions & 2 deletions ALCF/helpers.sh
@@ -922,8 +922,8 @@ buildDSconfig() {
# export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}"
export DS_CONFIG="${WORKING_DIR}/ds-configs/ds_stage${ZERO_STAGE}_mb${MICRO_BATCH}_gb${GLOBAL_BATCH}_pp${PP}_${DTYPE}.json"
mkdir -p "$(dirname "${DS_CONFIG}")"
echo "DS_CONFIG: ${DS_CONFIG}"
printf "ZS: %s, MB: %s, GB: %s, PP: %s, DTYPE: %s" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
printf "DS_CONFIG: %s\n" "${DS_CONFIG}"
printf "ZS=%s, MB=%s, GB=%s, PP=%s, DTYPE=%s\n" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
generateDSconfig "${DS_CONFIG}"
cat "${DS_CONFIG}" | jq .
}
20 changes: 11 additions & 9 deletions megatron/arguments.py
@@ -954,20 +954,22 @@ def _add_training_args(parser):
default='adam',
choices=[
'adam',
'adam8bit',
'adamw',
'sophiag',
'sgd',
'ds.fusedlamb',
'ipex.lamb',
'ipex.fusedlamb',
'adamwschedulefree',
'apex.adam',
'apex.sgd',
'adamwschedulefree',
'sgdschedulefree',
'ds.fusedlamb',
'ds.onebitlamb',
'galoreadamw',
'adam8bit',
'galoreadamw8bit',
'galoreadamw8bitperlayer'
'galoreadamw8bitperlayer',
'ipex.fusedlamb',
'ipex.lamb',
'shampoo',
'sgd',
'sgdschedulefree',
'sophiag'
],
help='Optimizer function'
)
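For context, the list above feeds a plain argparse choice; the functional change here is the new 'shampoo' entry, while the remaining entries appear to be re-sorted alphabetically with a trailing-comma fix. A minimal sketch of how the flag is parsed, using an abridged choices list rather than Megatron's full argument plumbing:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    default="adam",
    choices=["adam", "adamw", "adamwschedulefree", "shampoo", "sgd", "sophiag"],  # abridged
    help="Optimizer function",
)

args = parser.parse_args(["--optimizer", "shampoo"])
print(args.optimizer)  # -> shampoo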
205 changes: 121 additions & 84 deletions megatron/optimizer/__init__.py
@@ -127,6 +127,8 @@ def get_megatron_optimizer(
param_groups
)

optimizer = None
# ---- CPU Optimizer --------------------------------------
if args.cpu_optimizer:
assert args.optimizer == 'adam', 'CPU offloading is for Adam'
if args.cpu_torch_adam:
@@ -141,52 +143,73 @@ get_megatron_optimizer(
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
)

elif args.optimizer.lower() == "galore_adamw":
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit
# redefine way to call galore_adamw
optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == "galore_adamw":
# redefine way to call galore_adamw
optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# implement adafactor
elif args.optimizer.lower() == "adafactor":
import transformers
args.beta1 = None if args.beta1 == 0.0 else args.beta1
optimizer = transformers.optimization.Adafactor(
# ---- Adam --------------------------------------
elif args.optimizer == 'adam':
if args.ds_fused_adam:
# global Adam
from deepspeed.ops.adam import FusedAdam
Adam = FusedAdam
else:
Adam = torch.optim.Adam
optimizer = Adam(
param_groups,
lr=args.lr,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=args.beta1,
weight_decay=args.weight_decay,
relative_step=False,
scale_parameter=False,
warmup_init=False,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
)
# low-rank adafactor
elif args.optimizer.lower() == "galore_adafactor":
args.beta1 = None if args.beta1 == 0.0 else args.beta1
optimizer = GaLoreAdafactor(
# ---- apex.Adam --------------------------------------------
elif str(args.optimizer).lower() == 'apex.adam':
assert get_accelerator().device_name() == 'cuda'
from apex.optimizers import FusedAdam as Adam
optimizer = Adam(
param_groups,
lr=args.lr,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=args.beta1,
weight_decay=args.weight_decay,
relative_step=False,
scale_parameter=False,
warmup_init=False,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
)
# 8-bit Adam
# ---- Adam8Bit --------------------------------------
elif args.optimizer.lower() == "adam8bit":
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# ---- AdamW --------------------------------------
elif str(args.optimizer).lower() == 'adamw':
optimizer = torch.optim.AdamW(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
)
# ---- AdamW: ScheduleFree -------------------------------------
elif str(args.optimizer).lower() == 'adamwschedulefree':
import schedulefree
optimizer = schedulefree.AdamWScheduleFree(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
warmup_steps=args.lr_warmup_iters,
foreach=args.schedulefree_for_each,
)
# ---- AdamW: Galore ------------------------------------------
elif args.optimizer.lower() == "galore_adamw":
from galore_torch import GaLoreAdamW
# redefine way to call galore_adamw
optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# elif args.optimizer.lower() == "galore_adamw":
# from galore_torch import GaLoreAdamW
# # redefine way to call galore_adamw
# optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# ---- AdamW: GaLore 8Bit --------------------------------------
elif args.optimizer.lower() == "galore_adamw8bit":
from galore_torch import GaLoreAdamW8bit
optimizer = GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# ---- AdamW8bitPerLayer: GaLore ----------------------------
elif args.optimizer.lower() == 'galore_adamw8bit_per_layer':
from galore_torch import GaLoreAdamW8bit
# TODO: the scheduler seems to be called twice per update step; needs checking. For now, double num_training_steps, warmup_steps, and update_proj_gap.
optimizer_dict = {}
for p in model.parameters():
@@ -219,45 +242,48 @@ def optimizer_hook(p):
if p.requires_grad:
p.register_post_accumulate_grad_hook(optimizer_hook)
layer_wise_flag = True
elif str(args.optimizer) == 'ipex.lamb':
from intel_extension_for_pytorch.optim._lamb import Lamb
optimizer = Lamb(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
)
elif str(args.optimizer) == 'ipex.fusedlamb':
from intel_extension_for_pytorch.optim._lamb import Lamb
optimizer = Lamb(
# ---- AdaFactor --------------------------------------
elif args.optimizer.lower() == "adafactor":
import transformers
args.beta1 = None if args.beta1 == 0.0 else args.beta1
optimizer = transformers.optimization.Adafactor(
param_groups,
lr=args.lr,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=args.beta1,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
fused=True,
relative_step=False,
scale_parameter=False,
warmup_init=False,
)
elif str(args.optimizer).lower() == 'ds.fusedlamb':
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(
# ---- GaLore: Adafactor ------------------------------------
elif args.optimizer.lower() == "galore_adafactor":
from galore_torch import GaLoreAdafactor
args.beta1 = None if args.beta1 == 0.0 else args.beta1
optimizer = GaLoreAdafactor(
param_groups,
lr=args.lr,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=args.beta1,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
relative_step=False,
scale_parameter=False,
warmup_init=False,
)
elif str(args.optimizer).lower() == 'adamwschedulefree':
import schedulefree
optimizer = schedulefree.AdamWScheduleFree(
# ---- Apex: sgd ---------------------------------------------
elif str(args.optimizer).lower() == 'apex.sgd':
from apex.optimizers import FusedSGD as SGD
optimizer = SGD(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
warmup_steps=args.lr_warmup_iters,
foreach=args.schedulefree_for_each,
momentum=args.sgd_momentum
)
# ---- ScheduleFree: SGD -------------------------------
elif str(args.optimizer).lower() == 'sgdschedulefree':
import schedulefree
optimizer = schedulefree.SGDScheduleFree(
@@ -268,45 +294,54 @@ def optimizer_hook(p):
warmup_steps=args.lr_warmup_iters,
foreach=args.schedulefree_for_each,
)
elif str(args.optimizer).lower() == 'apex.adam':
assert get_accelerator().device_name() == 'cuda'
from apex.optimizers import FusedAdam as Adam
optimizer = Adam(
# ---- Lamb: Ipex --------------------------------------------
elif str(args.optimizer) == 'ipex.lamb':
from intel_extension_for_pytorch.optim._lamb import Lamb
optimizer = Lamb(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
)
elif str(args.optimizer).lower() == 'apex.sgd':
from apex.optimizers import FusedSGD as SGD
optimizer = SGD(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum
eps=args.adam_eps,
)
elif str(args.optimizer).lower() == 'adamw':
optimizer = torch.optim.AdamW(
# ---- Lamb(Fused): Ipex ----------------------------------------
elif str(args.optimizer) == 'ipex.fusedlamb':
from intel_extension_for_pytorch.optim._lamb import Lamb
optimizer = Lamb(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
eps=args.adam_eps,
fused=True,
)
elif args.optimizer == 'adam':
if args.ds_fused_adam:
# global Adam
from deepspeed.ops.adam import FusedAdam
Adam = FusedAdam
else:
Adam = torch.optim.Adam
optimizer = Adam(
# ---- Lamb(Fused): DeepSpeed ------------------------------------------
elif str(args.optimizer).lower() == 'ds.fusedlamb':
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(
param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps
eps=args.adam_eps,
)
# ---- Shampoo ----------------------------------------
elif args.optimizer == 'shampoo':
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig
optimizer = DistributedShampoo(
model.parameters(),
lr=0.001,
betas=(0.9, 0.999),
epsilon=1e-12,
weight_decay=1e-05,
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
grafting_config=AdamGraftingConfig(
beta2=0.999,
epsilon=1e-08,
),
)
elif args.optimizer == 'sgd':
optimizer = torch.optim.SGD(
@@ -326,8 +361,10 @@ def optimizer_hook(p):
)
else:
raise TypeError(f'{args.optimizer} optimizer is not supported.')
assert optimizer is not None
if args.deepspeed:
return optimizer

# Determine whether the params have main-grad field.
params_have_main_grad = False
if args.use_contiguous_buffers_in_local_ddp:
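The newly added 'shampoo' branch above constructs a DistributedShampoo optimizer with hard-coded hyperparameters (note that it passes model.parameters() and lr=0.001 directly, rather than the param_groups and args.lr used by the other branches). A standalone usage sketch follows, with the hyperparameters copied from the diff; it assumes Meta's distributed_shampoo package is installed, and the toy linear model plus the single-process gloo process group are illustrative scaffolding, not part of the commit.

import torch
import torch.distributed as dist
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig

# Single-process group, in case this distributed_shampoo version expects
# torch.distributed to be initialized.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )

model = torch.nn.Linear(16, 4)  # toy stand-in for the transformer model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(beta2=0.999, epsilon=1e-08),
)

# One toy update step to show the optimizer in use.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()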
1 change: 0 additions & 1 deletion pretrain_gpt_alcf.py
@@ -1,7 +1,6 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""

import time
from typing import Callable
from mpi4py import MPI