Pull in upstream changes from argonne-lcf @ main #12

Merged (9 commits) on Dec 25, 2024
ALCF/helpers.sh: 52 changes (39 additions, 13 deletions)
@@ -194,10 +194,13 @@ setup_run_cmd() {
# min_lr=$(python3 -c 'print(f"{2 / (10 ** 5):.8f}")')
# "--min-lr ${LR:-${min_lr}}" # 2e-5
# "--min-lr ${MIN_LR:-"2e-6"}" # 2e-5
export LR="${LR:-0.0002}"
export LR_DECAY_STYLE="${LR_DECAY_STYLE:-cosine}"
export LR_WARMUP_FRAC="${LR_WARMUP_FRAC:-0.05}"
lr_flags=(
"--lr ${LR:-0.0002}"
"--lr-decay-style ${LR_DECAY_STYLE:-cosine}"
"--lr-warmup-fraction ${LR_WARMUP_FRAC:-0.05}"
"--lr ${LR}"
"--lr-decay-style ${LR_DECAY_STYLE}"
"--lr-warmup-fraction ${LR_WARMUP_FRAC}"
)
if [[ -n "${LR_DECAY_ITERS:-}" ]]; then
lr_flags+=("--lr-decay-iters ${LR_DECAY_ITERS:-}")
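This hunk lifts the learning-rate defaults into exported variables and lets the flag array reference them directly, so each default is applied exactly once. A minimal sketch of the pattern, with a hypothetical launch command for the override path:

```bash
# Export once with a fallback; downstream flag strings then always see a defined value.
export LR="${LR:-0.0002}"      # keeps a caller-supplied LR, otherwise falls back to 2e-4
echo "--lr ${LR}"              # no second fallback needed when building flags

# Overriding from the calling shell (launch script name is an assumption):
# LR=3e-4 LR_DECAY_STYLE=linear LR_WARMUP_FRAC=0.01 bash train_llama_alcf.sh
```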
@@ -225,9 +228,9 @@ setup_run_cmd() {
"${lr_flags[@]}"
"${custom_args[@]}"
"${llama_flags[@]}"
"${DATA_FLAGS}"
"${FLASH_ARG}"
"${TIMING_STR}"
"${TIMING_STR:-}"
"${DATA_FLAGS}"
"${TOKENIZER_FLAGS}"
"${tb_flags[@]}"
"${ds_args[@]}"
@@ -316,6 +319,8 @@ get_machine_name() {
else
machine="polaris"
fi
elif [[ $(hostname) == sophia* ]]; then
machine="sophia"
elif [[ $(hostname) == nid* ]]; then
machine="perlmutter"
else
@@ -325,6 +330,7 @@
}

get_machine() {
machine=$(hostname)
if [[ $(hostname) == x4* ]]; then
machine="aurora"
elif [[ $(hostname) == x1* ]]; then
@@ -335,6 +341,8 @@
else
machine="polaris"
fi
elif [[ $(hostname) == sophia* ]]; then
machine="sophia"
elif [[ $(hostname) == nid* ]]; then
machine="perlmutter"
else
@@ -366,7 +374,7 @@ setupSrun() {

printJobInfo() {
echo "++++++++++++++++++++++++++++++++++++++++++++++++++"
echo "- MPICH_DIR=${MPICH_DIR:-${MPI_ROOT}}"
echo "- MPICH_DIR=${MPICH_DIR:-${MPI_ROOT:-}}"
echo "- Using $(which python3)"
echo "- WORLD_SIZE:${WORLD_SIZE-}"
echo "- BACKEND: ${BE:-}"
@@ -406,6 +414,8 @@ setupLauncher() {
mn=$(get_machine_name)
if [[ "${mn}" == "aurora" || "${mn}" == "sunspot" ]]; then
LAUNCHER="${DIST_LAUNCH} --pmi=pmix --genvall $(which python3) -Wignore ${EXEC}"
elif [[ "${mn}" == "sophia" ]]; then
LAUNCHER="${DIST_LAUNCH} $(which python3) -Wignore ${EXEC}"
else
LAUNCHER="${DIST_LAUNCH} --genvall $(which python3) -Wignore ${EXEC}"
fi
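The new Sophia branch drops both `--pmi=pmix` and `--genvall` from the launcher. `--genvall` is an MPICH/Hydra `mpiexec` option that forwards the caller's entire environment to every rank; one plausible reading, an assumption rather than anything stated in the PR, is that Sophia's launcher does not accept it, in which case environment forwarding would look different, for example the OpenMPI way:

```bash
# MPICH/Hydra style (as used on the other machines): forward the whole environment
mpiexec -n 8 --genvall python3 -Wignore pretrain_gpt.py      # script name illustrative

# OpenMPI style (hypothetical Sophia-side equivalent): forward variables explicitly
mpirun -np 8 -x PATH -x LD_LIBRARY_PATH python3 -Wignore pretrain_gpt.py
```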
@@ -512,9 +522,9 @@ get_grad_acc_steps_on_aurora() {
gas=1
elif [[ 128 -le "${nhosts}" && "${nhosts}" -lt 256 ]]; then
gas=2
elif [[ 32 -le "${nhosts}" && "${nhosts}" -lt 128 ]]; then
elif [[ 32 -lt "${nhosts}" && "${nhosts}" -lt 129 ]]; then
gas=4
elif [[ 16 -le "${nhosts}" && "${nhosts}" -lt 32 ]]; then
elif [[ 16 -le "${nhosts}" && "${nhosts}" -le 32 ]]; then
gas=8
else
gas=16
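The boundary tweak changes which branch catches exactly 32 hosts: previously `32 <= nhosts < 128` routed 32 nodes to `gas=4`, while the new `32 < nhosts < 129` excludes 32, so it falls through to the `16 <= nhosts <= 32` branch and gets `gas=8` (128 itself is unaffected, since the `128 <= nhosts < 256` case above matches first either way). A standalone probe of the revised ranges; the condition guarding the `gas=1` branch is not shown in the hunk and is assumed here to be `nhosts >= 256`:

```bash
for nhosts in 15 16 31 32 33 127 128 255 256; do
  if   [[ 256 -le "${nhosts}" ]]; then gas=1                          # assumed top branch
  elif [[ 128 -le "${nhosts}" && "${nhosts}" -lt 256 ]]; then gas=2
  elif [[ 32  -lt "${nhosts}" && "${nhosts}" -lt 129 ]]; then gas=4
  elif [[ 16  -le "${nhosts}" && "${nhosts}" -le 32  ]]; then gas=8   # 32 now lands here
  else gas=16; fi
  echo "nhosts=${nhosts} -> gas=${gas}"
done
```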
@@ -626,6 +636,22 @@ setParams() {
fi
echo "Setting up AWS NCCL OFI Plugin on Polaris..."
source "${WORKING_DIR}/ALCF/aws_ofi_nccl_plugin.sh" || exit
# ---- [Sophia] ----------------------
elif [[ "${mn}" == sophia* ]]; then
# export LAUNCH_CMD="${LAUNCH_CMD:-deepspeed}"
TP=${TP:-1} # TP (defaults to 1)
export NCCL=${NCCL:-nccl} # NCCL
export BE="${NCCL}" # BE = NCCL
export DTYPE=${DTYPE:-bf16} # DTYPE (defaults to bf16)
export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-8} # GRADIENT_ACC_STEPS
export MICRO_BATCH="${MICRO_BATCH:-$(get_batch_size_on_polaris)}"
if [[ -n "${NO_FLASH_ATTN-}" ]]; then
echo "Not using flash-attn!!"
else
FLASH_ARG="--use-flash-attn-v2"
fi
# echo "Setting up AWS NCCL OFI Plugin on Polaris..."
# source "${WORKING_DIR}/ALCF/aws_ofi_nccl_plugin.sh" || exit
# [Perlmutter]
elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then
TP="${TP:-2}"
@@ -896,8 +922,8 @@ buildDSconfig() {
# export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}"
export DS_CONFIG="${WORKING_DIR}/ds-configs/ds_stage${ZERO_STAGE}_mb${MICRO_BATCH}_gb${GLOBAL_BATCH}_pp${PP}_${DTYPE}.json"
mkdir -p "$(dirname "${DS_CONFIG}")"
echo "DS_CONFIG: ${DS_CONFIG}"
printf "ZS: %s, MB: %s, GB: %s, PP: %s, DTYPE: %s" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
printf "DS_CONFIG: %s\n" "${DS_CONFIG}"
printf "ZS=%s, MB=%s, GB=%s, PP=%s, DTYPE=%s\n" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
generateDSconfig "${DS_CONFIG}"
cat "${DS_CONFIG}" | jq .
}
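The switch to `printf` with an explicit `\n` fixes a cosmetic bug: the old format string had no trailing newline, so whatever was written to stdout next landed on the same line. A minimal repro:

```bash
printf "ZS: %s, MB: %s" 1 4; echo "next output"   # old style: ZS: 1, MB: 4next output
printf "ZS=%s, MB=%s\n" 1 4; echo "next output"   # new style: the lines stay separate
```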
@@ -1046,7 +1072,7 @@ setup_tokenizer_and_data() {
export TOKENIZER_TYPE="GPT2"
_tokenizer_flags+=("--tokenizer-type GPT2BPETokenizer")
machine=$(get_machine_name)
if [[ ${machine} == "polaris" ]]; then
if [[ ${machine} == "polaris" || ${machine} == "sophia" ]]; then
export DATA_PARENT="${DATA_PARENT:-/eagle/argonne_tpc/foremans/projects/argonne-lcf/Megatron-DeepSpeed/dataset}"
elif [[ ${machine} == "sunspot" ]]; then
export DATA_PARENT="${DATA_PARENT:-/gila/Aurora_deployment/foremans/anl_24_q2_release/Megatron-DeepSpeed/dataset}"
@@ -1075,7 +1101,7 @@ setup_tokenizer_and_data() {
echo "Using tokenizer: ${TOKENIZER_TYPE}. Setting up data with ${DATA_FILE_LIST:-}"
setData "${dfl}" || exit
fi
export DATA_FLAGS="${_data_flags[*]}"
export DATA_FLAGS="${_data_flags[*]:-}"
export TOKENIZER_FLAGS="${_tokenizer_flags[*]}"
printf "[setData] DATA_FLAGS: %s\n" "$(printGreen "${DATA_FLAGS}")"
printf "[setData] TOKENIZER_FLAGS: %s\n" "$(printMagenta "${TOKENIZER_FLAGS}")"
@@ -1113,7 +1139,7 @@ setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST]
printf "WEIGHT_SUM: %s\n" "${WEIGHT_SUM}"
printf "DFL_STEM: %s\n" "${DFL_STEM}"
printf "DATA_CACHE_PATH: %s\n" "${DATA_CACHE_PATH}"
printf "DATA_FLAGS: %s\n" "${DATA_FLAGS}"
printf "DATA_FLAGS: %s\n" "${DATA_FLAGS:-}"
echo "--------------------"
}

megatron/arguments.py: 20 changes (11 additions, 9 deletions)
@@ -954,20 +954,22 @@ def _add_training_args(parser):
default='adam',
choices=[
'adam',
'adam8bit',
'adamw',
'sophiag',
'sgd',
'ds.fusedlamb',
'ipex.lamb',
'ipex.fusedlamb',
'adamwschedulefree',
'apex.adam',
'apex.sgd',
'adamwschedulefree',
'sgdschedulefree',
'ds.fusedlamb',
'ds.onebitlamb',
'galoreadamw',
'adam8bit',
'galoreadamw8bit',
'galoreadamw8bitperlayer'
'galoreadamw8bitperlayer',
'ipex.fusedlamb',
'ipex.lamb',
'shampoo',
'sgd',
'sgdschedulefree',
'sophiag'
],
help='Optimizer function'
)
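The hunk re-sorts the optimizer choices alphabetically; `shampoo` now appears in the list, and the `'galoreadamw8bitperlayer'` entry carries a trailing comma. These values feed Megatron's optimizer selection argument; a hedged launch-time example (the `--optimizer` flag name and the script name are assumptions based on the surrounding parser code):

```bash
python3 pretrain_gpt.py --optimizer shampoo ...
python3 pretrain_gpt.py --optimizer adamwschedulefree ...
```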