diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 0000000000..2f9765af24 --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,35 @@ +name: python + +on: + workflow_dispatch: + pull_request: + branches: + '**' + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + strategy: + matrix: + pyVersion: ["3.7", "3.8", "3.9", "3.10"] + fail-fast: false + + runs-on: ubuntu-22.04 + container: + image: deepspeed/gh-builder:py${{ matrix.pyVersion }} + + steps: + - uses: actions/checkout@v4 + + - name: environment + run: | + which python + python --version + - name: Install Megatron-DeepSpeed + run: | + pip3 install . diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index 7720a919d9..5df9a2c7a5 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -20,7 +20,6 @@ # command for launching across all GPUs in our active PBS job. ############################################################################### - ################## # helpers_main # @@ -75,7 +74,7 @@ helpers_main() { # # - Clone + Install [`saforem2/ezpz`](https://github.com/saforem2/ezpz) # - Source [`ezpz/utils.sh`](https://github.com/saforem2/ezpz/blob/main/src/ezpz/bin/utils.sh) -# - This provides `{ezpz_setup_python, ezpz_setup_alcf}` (called below) +# - This provides `{ezpz_setup_python, ezpz_setup_job}` (called below) # # - Set runtime options # @@ -113,22 +112,23 @@ setup() { # - NGPU_PER_HOST (by magic) # - NGPUS (= NHOSTS * NGPU_PER_HOST) # 4. Use these (^) to build our launch command - ezpz_setup "$@" || exit + ezpz_setup || exit ########################################################################## install_dependencies # Set command line arguments to pass to `"${EXEC}"` - setParams "$@" || exit + setParams || exit # Create `deepspeed_config.json` from runtime params from ^ buildDSconfig || exit # Specify output directory for {logs, checkpoints, etc.} + setup_checkpoint || exit setOutput || exit # Specify additional `deepspeed` arguments (dependent on _newly created_ variables) set_args || exit # Ensure executable exists in expected path check_executable "${EXEC:-${WORKING_DIR}/pretrain_gpt_alcf.py}" - dfl="${DATA_FILE_LIST:-}" + dfl="${DATA_FILE_LIST:-"${PBS_O_WORKDIR}/ALCF/data-lists/$(get_machine_name)/dolma.txt"}" # Setup data + tokenizer via `DATA_FILE_LIST` and `TOKENIZER_TYPE` - tok="${TOKENIZER_TYPE:-Llama2}" + tok="${TOKENIZER_TYPE:-Llama2Tokenizer}" setup_tokenizer_and_data "${tok}" "${dfl}" || exit make_data || exit # Print job info @@ -138,20 +138,28 @@ setup() { # Check that were not already running, if so, exit. check_and_kill_if_running || exit # Setup run command to be executed - setup_run_cmd || exit + setup_run_cmd "$@" || exit } + ##################################################### # setup_run_cmd # # Build run command to be executed. 
##################################################### setup_run_cmd() { + ############################## + # take in additional arguments + # and append them directly to + # the end of the `run_cmd` + # custom_args="$@" + custom_args=("$@") + ############################## #### Make it easy to track experiments by date ################### year="$(date "+%Y")" month="$(date "+%m")" day="$(date "+%Y-%m-%d")" - today="$(date "+%Y-%m-%d")" # kept for backwards compatibility + today="$(date "+%Y-%m-%d")" # kept for backwards compatibility started_at="$(date "+%Y-%m-%d-%H%M%S")" export YEAR="${year}" export MONTH="${month}" @@ -163,74 +171,122 @@ setup_run_cmd() { # `export LAUNCH_WITH=deepspeeed && bash train_llama_alcf.sh` ################################################################## setupLauncher "${LAUNCH_WITH:-MPICH}" || exit - TBDIR="${CKPT_DIR}/tensorboard" - mkdir -p "${TBDIR}" export data_cache_path="${CKPT_DIR}/${DATA_CACHE_PATH}" && mkdir -p "${data_cache_path}" printf "\n" echo "Using data_cache_path: ${data_cache_path}" - export DEFAULTS="\ - --split 100,0,0 \ - --log-interval 1 \ - --no-bias-gelu-fusion \ - --no-bias-dropout-fusion \ - --no-masked-softmax-fusion \ - --no-gradient-accumulation-fusion \ - --accumulate-allreduce-grads-in-fp32 \ - --use-checkpoint-opt_param-scheduler \ - --log-timers-to-tensorboard \ - --log-optimizer-states-to-tensorboard" - if [[ "${SP}" -ge 2 ]]; then - export DEFAULTS="${DEFAULTS} --ds-sequence-parallel-size ${SP} --force-ds-sequence-parallel" + TRAIN_SPLIT="${TRAIN_SPLIT:-100}" + VAL_SPLIT="${VAL_SPLIT:-0}" + TEST_SPLIT="${TEST_SPLIT:-0}" + LOG_INTERVAL="${LOG_INTERVAL:-1}" + DEFAULTS=( + "--split ${TRAIN_SPLIT},${VAL_SPLIT},${TEST_SPLIT}" + "--log-interval ${LOG_INTERVAL}" + "--no-bias-gelu-fusion" + "--no-bias-dropout-fusion" + "--no-masked-softmax-fusion" + "--no-gradient-accumulation-fusion" + "--accumulate-allreduce-grads-in-fp32" + ) + # export DEFAULTS="\ + # --split ${TRAIN_SPLIT},${VAL_SPLIT},${TEST_SPLIT} \ + # --log-interval ${LOG_INTERVAL} \ + # --no-bias-gelu-fusion \ + # --no-bias-dropout-fusion \ + # --no-masked-softmax-fusion \ + # --no-gradient-accumulation-fusion \ + # --accumulate-allreduce-grads-in-fp32" + # OVERRIDE_CKPT_OPT_PARAM="${OVERRIDE_CKPT_OPT_PARAM:-}" + if [[ -z "${OVERRIDE_CKPT_OPT_PARAM:-}" ]]; then + DEFAULTS+=("--use-checkpoint-opt_param-scheduler") + fi + if [[ "${SP}" -gt 1 ]]; then + DEFAULTS+=( + "--ds-sequence-parallel-size ${SP}" + "--force-ds-sequence-parallel" + ) fi ################################################################## # WARN: to disable Llama-type architectures, toggle via: # `NO_LLAMA=1 bash train_llama_alcf.sh` ################################################################## - if [[ -z "${NO_LLAMA:-}" ]]; then - llama_flags="${LLAMA_ARGS}\ - --num-key-value-heads ${NUM_KV_HEAD} \ - --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ - " + LLAMA_ARGS="" + if [[ "${SP}" == 1 ]]; then + export LLAMA_ARGS="${LLAMA_ARGS} " else - echo "!! Running in NO_LLAMA MODE !!" - llama_flags="" + export LLAMA_ARGS="" + echo "NOT USING ROTARY EMBEDDINGS! 
LLAMA_ARGS=${LLAMA_ARGS}" + fi + if [[ -z "${NO_LLAMA:-}" ]]; then + llama_flags=( + "--swiglu" + "--hidden-dropout 0" + "--attention-dropout 0" + "--normalization rmsnorm" + "--disable-bias-linear" + "--no-query-key-layer-scaling" + "--use-rotary-position-embeddings" + "--untie-embeddings-and-output-weights" + "--num-key-value-heads ${NUM_KV_HEAD}" + "--ffn-hidden-size ${FFN_HIDDEN_SIZE}" + ) fi - export run_cmd=" - ${LAUNCHER} \ - --${DTYPE} \ - ${DEFAULTS} \ - --optimizer ${OPT} \ - --adam-beta1=${ADAM_BETA1} \ - --adam-beta2=${ADAM_BETA2} \ - --adam-eps=${ADAM_EPS} \ - --weight-decay=${WEIGHT_DECAY} \ - --save ${CKPT_DIR} \ - --load ${CKPT_DIR} \ - --seq-length ${SEQ} \ - --num-layers ${NLAYERS} \ - --hidden-size ${HIDDEN} \ - --tensorboard-dir ${TBDIR} \ - --train-iters ${TRAIN_ITERS} \ - --eval-iters ${EVAL_ITERS} \ - --distributed-backend ${BE} \ - --num-attention-heads ${HEADS} \ - --save-interval ${SAVE_INTERVAL} \ - --eval-interval ${EVAL_INTERVAL} \ - --max-position-embeddings ${SEQ} \ - --micro-batch-size ${MICRO_BATCH} \ - --tensor-model-parallel-size ${TP} \ - --global-batch-size ${GLOBAL_BATCH} \ - --pipeline-model-parallel-size ${PP} \ - --data-cache-path ${data_cache_path} \ - ${DATA_FLAGS} \ - ${LR_ARGS} \ - ${llama_flags} \ - ${FLASH_ARG} \ - ${TIMING_STR} \ - ${TOKENIZER_FLAGS} \ - ${ds_args} \ - ${gpt_args[*]} - " + + TENSORBARD_ARGS=() + if [[ -z "${USE_TENSORBARD:-}" ]]; then + TBDIR="${CKPT_DIR}/tensorboard" + mkdir -p "${TBDIR}" + # --log-timers-to-tensorboard \ + # --log-optimizer-states-to-tensorboard" + # --tensorboard-dir ${TBDIR} \ + TENSORBARD_ARGS+=( + "--log-timers-to-tensorboard" + "--log-optimizer-states-to-tensorboard" + "--tensorboard-dir ${TBDIR}" + ) + fi + dfl_fallback="${DATA_FILE_LIST:-${PBS_O_WORKDIR}/ALCF/data-lists/$(get_machine_name)/dolma.txt}" + export ADAM_BETA1="${ADAM_BETA1:-0.9}" + export ADAM_BETA2="${ADAM_BETA2:-0.95}" + export ADAM_EPS="${ADAM_EPS:-0.00001}" # 1 * 10^{-5} + export run_cmd=( + "${LAUNCHER}" + "--${DTYPE}" + "${DEFAULTS[@]}" + "--optimizer ${OPT}" + "--save ${CKPT_DIR}" + "--load ${CKPT_DIR}" + "--seq-length ${SEQ}" + "--num-layers ${NLAYERS}" + "--hidden-size ${HIDDEN}" + "--train-iters ${TRAIN_ITERS}" + "--eval-iters ${EVAL_ITERS}" + "--distributed-backend ${BE}" + "--adam-beta1 ${ADAM_BETA1:-0.9}" + "--adam-beta2 ${ADAM_BETA2:-0.95}" + "--adam-eps ${ADAM_EPS:-0.00001}" + "--clip-grad ${CLIP_GRAD:-1.0}" + "--weight-decay ${WEIGHT_DECAY:-0.1}" + "--num-attention-heads ${HEADS}" + "--save-interval ${SAVE_INTERVAL}" + "--eval-interval ${EVAL_INTERVAL}" + "--max-position-embeddings ${SEQ}" + "--micro-batch-size ${MICRO_BATCH}" + "--tensor-model-parallel-size ${TP}" + "--global-batch-size ${GLOBAL_BATCH}" + "--pipeline-model-parallel-size ${PP}" + "--data-cache-path ${data_cache_path}" + "--data-file-list ${DATA_FILE_LIST:-${dfl_fallback}}" + "${TENSORBARD_ARGS[@]}" + "${DATA_FLAGS}" + "${LR_ARGS}" + "${llama_flags[@]}" + "${FLASH_ARG}" + "${TIMING_STR}" + "${TOKENIZER_FLAGS}" + "${ds_args[@]}" + "${gpt_args[@]}" + "${custom_args[@]}" + ) } save_dotenv() { @@ -243,7 +299,7 @@ save_dotenv() { module list dotenv_file="${outdir}/.env" echo "Saving environment to ${dotenv_file}" - printenv | grep -v "LS_COLORS" > "${dotenv_file}" + printenv | grep -v "LS_COLORS" >"${dotenv_file}" export DOTENV_FILE="${dotenv_file}" fi } @@ -298,29 +354,26 @@ get_machine() { printf "Running on: %s\n" "$(printBlue "${MACHINE}")" } - check_and_kill_if_running() { RUNNING_PIDS=$(lsof -i:29500 -Fp | head -n 1 | sed 's/^p//') - if [[ -n 
"${RUNNING_PIDS}" ]]; - then echo "Caught ${RUNNING_PIDS}" && kill "${RUNNING_PIDS}"; + if [[ -n "${RUNNING_PIDS}" ]]; then + echo "Caught ${RUNNING_PIDS}" && kill "${RUNNING_PIDS}" else echo "Not currently running. Continuing!" fi } - setupSrun() { if [[ $(hostname) == login* || $(hostname) == nid* ]]; then export NHOSTS="${SLURM_NNODES:-1}" export NGPU_PER_HOST="${SLURM_GPUS_ON_NODE:-$(nvidia-smi -L | wc -l)}" - export NGPUS="$(( NHOSTS * NGPU_PER_HOST ))" + export NGPUS="$((NHOSTS * NGPU_PER_HOST))" export SRUN_EXEC="srun --gpus ${NGPUS} --gpus-per-node ${NGPU_PER_HOST} -N ${NHOSTS} -n ${NGPUS} -l -u --verbose" else echo "Skipping setupSrun() on $(hostname)" fi } - printJobInfo() { echo "++++++++++++++++++++++++++++++++++++++++++++++++++" echo "- MPICH_DIR=${MPICH_DIR:-${MPI_ROOT}}" @@ -377,7 +430,6 @@ setupLauncher() { printf " %s" "$(printMagenta "${LAUNCHER}")" } - set_lr_args() { LR_ARGS="--lr ${LR} --lr-decay-style cosine" if [[ -n "${LR_DECAY_ITERS:-}" ]]; then @@ -390,7 +442,6 @@ set_lr_args() { export LR_ARGS="${LR_ARGS}" } - ######################################################################### # `get_batch_size_on_polaris`: Identify MICRO_BATCH to use on Polaris. # @@ -407,8 +458,8 @@ set_lr_args() { ######################################################################### get_batch_size_on_polaris() { if [[ $(hostname) == x3* ]]; then - nhosts=$(wc -l < "${HOSTFILE:-${PBS_NODEFILE}}") - if [[ "${nhosts}" == 1 || "${nhosts}" == 2 ]]; then + nhosts=$(wc -l <"${HOSTFILE:-${PBS_NODEFILE}}") + if [[ "${nhosts}" == 1 || "${nhosts}" == 2 ]]; then mbs=1 elif [[ "${nhosts}" -ge 3 && "${nhosts}" -le 7 ]]; then mbs=2 @@ -420,9 +471,9 @@ get_batch_size_on_polaris() { } _get_num_hosts_from_hostfile() { - if [[ "$#" == 1 ]]; then + if [[ "$#" == 1 ]]; then if [[ -f "$1" ]]; then - nhosts=$(wc -l < "$1") + nhosts=$(wc -l <"$1") echo "${nhosts}" else exit 1 @@ -460,7 +511,7 @@ get_grad_acc_steps_on_aurora() { echo "Expected exactly 0 or 1 arguments, received: $#" exit 1 fi - nhosts=$(wc -l < "${hf}") + nhosts=$(wc -l <"${hf}") if [[ 64 -le "${nhosts}" ]]; then gas=1 elif [[ 32 -le "${nhosts}" && "${nhosts}" -lt 64 ]]; then @@ -473,6 +524,33 @@ get_grad_acc_steps_on_aurora() { echo "${gas}" } +set_ccl_vars_on_aurora() { + export CCL_KVS_MODE=mpi + export CCL_CONFIGURATION_PATH="" + export CCL_CONFIGURATION=cpu_gpu_dpcpp + # export CCL_ROOT=/tmp/oneccl/ + # export LD_LIBRARY_PATH=${CCL_ROOT}/lib:$LD_LIBRARY_PATH + # export CPATH=${CCL_ROOT}/include:$CPATH + # export LIBRARY_PATH=${CCL_ROOT}/lib:$LIBRARY_PATH + export CCL_KVS_CONNECTION_TIMEOUT=3600 + export FI_CXI_RX_MATCH_MODE=hybrid + export CCL_BCAST=double_tree + + export ZE_ENABLE_PCI_ID_DEVICE_ORDER=1 + export CCL_PROCESS_LAUNCHER=pmix # Required by Aurora mpich + export FI_PROVIDER=cxi # Required by Aurora mpich + export PALS_PMI=pmix # Required by Aurora mpich + export CCL_ATL_TRANSPORT=mpi # Required by Aurora mpich + export TORCH_LLM_ALLREDUCE=1 + export CCL_SYCL_ESIMD=1 + export CCL_ALLGATHERV_MEDIUM_SIZE_THRESHOLD=0 # Required by current oneCCL (MLSL-2881) + export CCL_ENABLE_SYCL_KERNELS=1 + export CCL_WORKER_AFFINITY=5,13,21,29,37,45,57,65,73,81,89,97 + export CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD=32768 + export FI_CXI_DEFAULT_CQ_SIZE=1048576 + export FI_CXI_RX_MATCH_MODE=hybrid + export CCL_BCAST=double_tree +} ############################################################################## # setParams @@ -487,26 +565,29 @@ get_grad_acc_steps_on_aurora() { 
############################################################################## setParams() { FLASH_ARG="" - LLAMA_ARGS="--attention-dropout 0 --hidden-dropout 0" # ---- [Parallelism Settings] -------------------------------------------+ # ------ [Aurora] -------||------ [SunSpot] ------------- # if [[ $(hostname) == x4* || $(hostname) == x1* ]]; then mn=$(get_machine_name) if [[ "${mn}" == "aurora" || "${mn}" == "sunspot" ]]; then - TP=${TP:-1} # TP = 1 + TP=${TP:-1} # TP = 1 export SAVE_INTERVAL="${SAVE_INTERVAL:-20}" - export CCL=${CCL:-ccl} # CCL - export BE="${CCL}" # COMMUNICATION BACKEND = CCL - export DTYPE=${DTYPE:-bf16} # DTYPE: bf16 + export CCL=${CCL:-ccl} # CCL + export BE="${CCL}" # COMMUNICATION BACKEND = CCL + export DTYPE=${DTYPE:-bf16} # DTYPE: bf16 # export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} # GRADIENT_ACC_STEPS - gas=$(get_grad_acc_steps_on_aurora "$@") + gas=$(get_grad_acc_steps_on_aurora "${PBS_NODEFILE:-${HOSTFILE:-${hostfile}}}") export GRAD_ACC_STEPS="${GRAD_ACC_STEPS:-${gas}}" # export GRAD_ACC_STEPS="${GRAD_ACC_STEPS:-$(get_grad_acc_steps_on_aurora "$@)}" echo "[setParams] Using GRAD_ACC_STEPS: ${GRAD_ACC_STEPS}" - MICRO_BATCH=${MICRO_BATCH:-4} # MICRO_BATCH = 4 - export CCL_PROCESS_LAUNCHER=pmix - export CCL_ATL_TRANSPORT=mpi - ###################################################################### + MICRO_BATCH=${MICRO_BATCH:-4} # MICRO_BATCH = 4 + #### [sam: 08/17/2024] ########################################## + # Use best set of CCL env vars from Gordon Bell runs on Aurora + set_ccl_vars_on_aurora + ################################################################# + #### [sam: 06/20/2024] ############################################### + # export CCL_PROCESS_LAUNCHER=pmix + # export CCL_ATL_TRANSPORT=mpi # !XXX: USE KEY VALUE STORE FIX ON AURORA [2024-06-20] # use_kvs_fix_on_aurora # <-- why are these different from those in update_ccl_env_vars_aurora ?? # update_ccl_env_vars_aurora @@ -528,12 +609,12 @@ setParams() { # elif [[ $(hostname) == x3* ]]; then elif [[ "${mn}" == "polaris" || "${mn}" == "sirius" ]]; then # export LAUNCH_CMD="${LAUNCH_CMD:-deepspeed}" - TP=${TP:-1} # TP = 2 - export NCCL=${NCCL:-nccl} # NCCL - export BE="${NCCL}" # BE = NCCL + TP=${TP:-1} # TP = 2 + export NCCL=${NCCL:-nccl} # NCCL + export BE="${NCCL}" # BE = NCCL # export DTYPE=${DTYPE:-bf16} # DTYPE: BF16 ?? 
- export DTYPE=${DTYPE:-fp16} # DTYPE: FP16 - export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-8} # GRADIENT_ACC_STEPS + export DTYPE=${DTYPE:-fp16} # DTYPE: FP16 + export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-8} # GRADIENT_ACC_STEPS # NOTE: MICRO_BATCH is exported below # MICRO_BATCH=${MICRO_BATCH:-2} # MICRO_BATCH = 8 export MICRO_BATCH="${MICRO_BATCH:-$(get_batch_size_on_polaris)}" @@ -546,7 +627,7 @@ setParams() { source "${WORKING_DIR}/ALCF/aws_ofi_nccl_plugin.sh" || exit # +--------[Perlmutter]---------------------------------+ # elif [[ $(hostname) == login* || $(hostname) == nid* ]]; then -elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then + elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then TP="${TP:-2}" export NCCL="${NCCL:-nccl}" export BE="${NCCL}" @@ -565,47 +646,47 @@ elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then export FLASH_ARG="${FLASH_ARG}" export DTYPE="${DTYPE:-bf16}" export OPT="${OPT:-adamw}" - export ADAM_BETA1="${ADAM_BETA1:-0.9}" - export ADAM_BETA2="${ADAM_BETA2:-0.95}" - export ADAM_EPS="${ADAM_EPS:-0.00001}" # 1 * 10^{-5} + # export ADAM_BETA1="${ADAM_BETA1:-0.9}" + # export ADAM_BETA2="${ADAM_BETA2:-0.95}" + # export ADAM_EPS="${ADAM_EPS:-0.00001}" # 1 * 10^{-5} export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" export HOSTFILE="${HOSTFILE:-${PBS_NODEFILE}}" - NHOSTS=$(wc -l < "${HOSTFILE}") + NHOSTS=$(wc -l <"${HOSTFILE}") if [[ -z "${NGPU_PER_HOST:-}" ]]; then NGPU_PER_HOST=$(python3 -c 'import ezpz as ez; print(ez.get_gpus_per_node())') fi export NGPU_PER_HOST="${NGPU_PER_HOST}" - export WORLD_SIZE="${WORLD_SIZE:-$(( NHOSTS * NGPU_PER_HOST ))}" + export WORLD_SIZE="${WORLD_SIZE:-$((NHOSTS * NGPU_PER_HOST))}" # +---[Llama2 7B Config]--------------------------------------------------+ # export MODEL_KEY="Llama-7B" - export HEADS=${HEADS:-${NHEADS:-32}} # NUMBER OF ATEN HEADS - export NLAYERS=${NLAYERS:-${NUM_LAYERS:-32}} # NUMBER OF LAYERS - export HIDDEN=${HIDDEN:-4096} # HIDDEN SIZE - export NUM_KV_HEAD=${NUM_KV_HEAD:-8} # GROUP ATTENTION - export FFN_HIDDEN_SIZE=${FFN_HIDDEN_SIZE:-11008} # FFN HIDDEN SIZE + export HEADS=${HEADS:-${NHEADS:-32}} # NUMBER OF ATEN HEADS + export NLAYERS=${NLAYERS:-${NUM_LAYERS:-32}} # NUMBER OF LAYERS + export HIDDEN=${HIDDEN:-4096} # HIDDEN SIZE + export NUM_KV_HEAD=${NUM_KV_HEAD:-8} # GROUP ATTENTION + export FFN_HIDDEN_SIZE=${FFN_HIDDEN_SIZE:-11008} # FFN HIDDEN SIZE # +---[Run Settings]------------------------------------------------------+ - export SEQ=${SEQ:-4096} # SEQ_LEN: 4096 - export ZERO_STAGE=${ZERO_STAGE:-1} # ZERO OFFLOADING STAGE - export MICRO_BATCH=${MICRO_BATCH:-8} # MICRO BATCH SIZE - export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} # GRADIENT ACCUMULATION STEPS - export EVAL_ITERS="${EVAL_ITERS:-10}" # NUMBER OF EVAL ITERS TO RUN - export EVAL_INTERVAL="${EVAL_INTERVAL:-50000}" # HOW FREQUENTLY TO RUN EVAL - export SAVE_INTERVAL=${SAVE_INTERVAL:-50} # HOW FREQUENTLY TO SAVE CKPTS - export TIMING_LOG_LEVEL="${TIMING_LOG_LEVEL:-1}" # TIMING VERBOSITY IN LOGS - export ACT_CKPT_NUM_LAYERS="${ACT_CKPT_NUM_LAYERS:-1}" # NUM LAYERS TO CHECKPOINT ACTIVATIONS - export USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING:-1} # USE ACTIVATION CHECKPOINTING ? 
- export GLOBAL_BATCH_MAX=$(( WORLD_SIZE * MICRO_BATCH * GRAD_ACC_STEPS / TP / PP / SP )) # MAX GLOBAL BATCH SIZE - export GLOBAL_BATCH="${GLOBAL_BATCH:-${GLOBAL_BATCH_MAX}}" # WILL USE MAX IF NOT SET IN ENVIRONMENT + export SEQ=${SEQ:-4096} # SEQ_LEN: 4096 + export ZERO_STAGE=${ZERO_STAGE:-1} # ZERO OFFLOADING STAGE + export MICRO_BATCH=${MICRO_BATCH:-8} # MICRO BATCH SIZE + export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} # GRADIENT ACCUMULATION STEPS + export EVAL_ITERS="${EVAL_ITERS:-10}" # NUMBER OF EVAL ITERS TO RUN + export EVAL_INTERVAL="${EVAL_INTERVAL:-50000}" # HOW FREQUENTLY TO RUN EVAL + export SAVE_INTERVAL=${SAVE_INTERVAL:-50} # HOW FREQUENTLY TO SAVE CKPTS + export TIMING_LOG_LEVEL="${TIMING_LOG_LEVEL:-1}" # TIMING VERBOSITY IN LOGS + export ACT_CKPT_NUM_LAYERS="${ACT_CKPT_NUM_LAYERS:-1}" # NUM LAYERS TO CHECKPOINT ACTIVATIONS + export USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING:-1} # USE ACTIVATION CHECKPOINTING ? + export GLOBAL_BATCH_MAX=$((WORLD_SIZE * MICRO_BATCH * GRAD_ACC_STEPS / TP / PP / SP)) # MAX GLOBAL BATCH SIZE + export GLOBAL_BATCH="${GLOBAL_BATCH:-${GLOBAL_BATCH_MAX}}" # WILL USE MAX IF NOT SET IN ENVIRONMENT # export TRAIN_ITER=${TRAIN_ITER:-317892} # NUMBER OF TRAIN ITERS if [[ -z "${TRAIN_ITERS:-${TRAIN_ITER:-}}" ]]; then export TRAIN_TOKENS=${TRAIN_TOKENS:-2000000000000} - export TRAIN_ITERS=$(( TRAIN_TOKENS / SEQ / GLOBAL_BATCH )) - printf "TRAIN_TOKENS=%s (=%sB tokens)\n" "${TRAIN_TOKENS}" "$(( TRAIN_TOKENS / 10**9 ))" + export TRAIN_ITERS=$((TRAIN_TOKENS / SEQ / GLOBAL_BATCH)) + printf "TRAIN_TOKENS=%s (=%sB tokens)\n" "${TRAIN_TOKENS}" "$((TRAIN_TOKENS / 10 ** 9))" printf "TRAIN_ITERS=%s\n" "${TRAIN_ITERS}" else export TRAIN_ITERS="${TRAIN_ITERS:-${TRAIN_ITER:-}}" fi - export MODEL_TYPE="llama-gb${GLOBAL_BATCH}-seq${SEQ}-pp${PP}-tp${TP}-${NLAYERS}layers-${HEADS}heads-${HIDDEN}hidden" # STRING FOR IDENTIFYING MODEL + export MODEL_TYPE="llama-gb${GLOBAL_BATCH}-seq${SEQ}-pp${PP}-tp${TP}-${NLAYERS}layers-${HEADS}heads-${HIDDEN}hidden" # STRING FOR IDENTIFYING MODEL # NOTE: [2024-07-10] ##################################################### # - [sam]: For whatever reason, it seems that using # sequence-parallelism (SP) > 1 is INCOMPATIBLE with @@ -613,30 +694,24 @@ elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then # # For this reason, we only use the default LLAMA_ARGS when SP=0. ########################################################################## - if [[ "${SP}" == 1 ]]; then - export LLAMA_ARGS="${LLAMA_ARGS} --no-query-key-layer-scaling --use-rotary-position-embeddings --untie-embeddings-and-output-weights --swiglu --normalization rmsnorm --disable-bias-linear" - else - export LLAMA_ARGS="" - echo "NOT USING ROTARY EMBEDDINGS! 
LLAMA_ARGS=${LLAMA_ARGS}" - fi # -----[Learning Rate Settings]-------------------------------------------- - export LR=${LR:-0.0003} # LEARNING_RATE - export LR_WARMUP_FRAC=${LR_WARMUP_FRAC:-0.05} # LEARNING RATE WARMUP + export LR=${LR:-0.0003} # LEARNING_RATE + export LR_WARMUP_FRAC=${LR_WARMUP_FRAC:-0.05} # LEARNING RATE WARMUP export LR_DECAY_ITERS=${LR_DECAY_ITERS:-} # LR DECAY ITERS set_lr_args # -----[Learning Rate Settings]-------------------------------------------- - if [[ "${TIMING_LOG_LEVEL}" -ge 1 ]]; then + # if [[ "${TIMING_LOG_LEVEL:-1}" -gt 1 ]]; then + if [[ "${TIMING_LOG_LEVEL:-1}" -gt 1 ]]; then TIMING_STR="\ - --timing-log-level ${TIMING_LOG_LEVEL} \ - --log-timers-to-tensorboard \ - --log-optimizer-states-to-tensorboard \ - " + --timing-log-level ${TIMING_LOG_LEVEL}" + # --log-timers-to-tensorboard \ + # --log-optimizer-states-to-tensorboard \ + # " else TIMING_STR="" fi } - ############################################## # set_args # @@ -645,19 +720,31 @@ elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then ############################################## set_args() { # ---- Set DeepSpeed arguments -------------------------------- - ds_args=" " - ds_args=" --deepspeed ${ds_args}" - if [[ $PP == 1 ]]; then - ds_args=" --no-pipeline-parallel ${ds_args}" + ds_args=( + "--deepspeed" + ) + if [[ "${PP:-1}" == 1 ]]; then + ds_args+=("--no-pipeline-parallel") fi - ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" - ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}" + ds_args+=("--deepspeed_config=${DS_CONFIG}") + ds_args+=("--zero-stage=$ZERO_STAGE") if [[ "${ZERO_STAGE}" == 3 ]]; then - ds_args="--use-mics ${ds_args}" + ds_args+=("--use-mics") fi + # ds_args=" " + # ds_args=" --deepspeed ${ds_args}" + # if [[ $PP == 1 ]]; then + # ds_args=" --no-pipeline-parallel ${ds_args}" + # fi + # ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" + # ds_args="--zero-stage=$ZERO_STAGE ${ds_args}" + # if [[ "${ZERO_STAGE}" == 3 ]]; then + # ds_args="--use-mics ${ds_args}" + # fi if [[ "$USE_ACTIVATION_CHECKPOINTING" == 1 ]]; then echo "!! Caught USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING} !!" - ds_args=" --deepspeed-activation-checkpointing ${ds_args}" + ds_args+=("--deepspeed-activation-checkpointing") + # ds_args=" --deepspeed-activation-checkpointing ${ds_args}" # --checkpoint-activations \ # --deepspeed-activation-checkpointing fi @@ -676,20 +763,18 @@ set_args() { export gpt_args } - make_ds_hostfile() { export GPUS_PER_NODE="${GPUS_PER_NODE:-${NGPU_PER_HOST:-${SLURM_GPUS_ON_NODE:-$(nvidia-smi -L | wc -l)}}}" # ---- Make MPICH hostfile ---------------- hf="${HOSTFILE:-${PBS_NODEFILE}}" export hostfile_mpich=hostfile_mpich - cat "${hf}" > "${hostfile_mpich}" + cat "${hf}" >"${hostfile_mpich}" # ---- Make DeepSpeed hostfile ------------------- export hostfile_deepspeed=hostfile_deepspeed - cat "${hf}" > "${hostfile_deepspeed}" + cat "${hf}" >"${hostfile_deepspeed}" sed -e "s/$/ slots=${GPUS_PER_NODE}/" -i "${hostfile_deepspeed}" } - ########################################### # ezpz_setup # @@ -697,12 +782,12 @@ make_ds_hostfile() { # to `"${WORKING_DIR}/deps/ezpz/"` # # 2. Source [`ezpz/src/ezpz/bin/utils.sh`](https://github.com/saforem2/ezpz/blob/main/src/ezpz/bin/utils.sh) -# - This provides `{ezpz_setup_python, ezpz_setup_alcf}` (called below) +# - This provides `{ezpz_setup_python, ezpz_setup_job}` (called below) # # 3. Call `ezpz_setup_python` (from `ezpz/bin/utils.sh`): # - This will setup conda + virtual enviroment # -# 4. 
Call `ezpz_setup_alcf` (from `ezpz/bin/utils.sh`): +# 4. Call `ezpz_setup_job` (from `ezpz/bin/utils.sh`): # - This will parse `$PBS_*` variables and build launch cmd # # 3. Call `_ezpz_install` (from `Megatron-DeepSpeed/ALCF/helpers.sh`): @@ -723,7 +808,7 @@ ezpz_setup() { # shellcheck source=../deps/ezpz/src/ezpz/bin/utils.sh source "${ezdir}/src/ezpz/bin/utils.sh" || exit ezpz_setup_python - ezpz_setup_alcf "$@" + ezpz_setup_job "$@" ezpz_pip_loc=$(python3 -m pip list | grep ezpz | awk '{print $NF}') if [[ -z "${ezpz_pip_loc:-}" ]]; then printf "[ezpz_install] Installing ezpz from %s\n" "${ezdir}" @@ -750,21 +835,20 @@ ezpz_test() { # saveDSenv # # Save important environment variables to .deepspeed_env, which will be -# forwarded to ALL ranks with DeepSpeed +# forwarded to ALL ranks with DeepSpeed ############################################################################ saveDSenv() { echo "Saving {PATH, LD_LIBRARY_PATH, htt{p,ps}_proxy, CFLAGS, PYTHONUSERBASE} to .deepspeed_env" { - echo "PATH=${PATH}" ; - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ; - echo "http_proxy=${http_proxy:-}" ; - echo "https_proxy=${https_proxy:-}" ; - echo "CFLAGS=${CFLAGS}" ; - echo "PYTHONUSERBASE=$PYTHONUSERBASE" ; - } > .deepspeed_env + echo "PATH=${PATH}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" + echo "http_proxy=${http_proxy:-}" + echo "https_proxy=${https_proxy:-}" + echo "CFLAGS=${CFLAGS}" + echo "PYTHONUSERBASE=$PYTHONUSERBASE" + } >.deepspeed_env } - get_output_prefix() { # ---- Specify output location -------------------------------- pre="ws${WORLD_SIZE}_ds_stage${ZERO_STAGE}_nl${NLAYERS}" @@ -773,7 +857,8 @@ get_output_prefix() { pre="${pre}_sp${SP}_pp${PP}_tp${TP}_${DTYPE}_opt${OPT}" pre="${pre}_lr${LR}_lwf${LR_WARMUP_FRAC}" if [[ -n "${TOKENIZER_TYPE:-}" ]]; then - pre="${pre}_tok${TOKENIZER_TYPE}" + _tok=$(echo "${TOKENIZER_TYPE}" | sed 's/Tokenizer//g') # noqa + pre="${pre}_tok${_tok}" fi if [[ -n "${LR_DECAY_ITERS}" ]]; then pre="${pre}_ldi${LR_DECAY_ITERS}" @@ -791,9 +876,21 @@ setOutput() { OUTPUT_DIR="logs/${OUTPUT_PREFIX}/$(date +%Y%m%d-%H%M%S)_${WORLD_SIZE}_${HOSTNAME}" export OUTPUT_DIR="${OUTPUT_DIR}" && mkdir -p "${OUTPUT_DIR}" export OUTPUT_LOG="${OUTPUT_DIR}/output.log" - export CKPT_DIR="checkpoints/${OUTPUT_PREFIX}" - echo "${OUTPUT_LOG}" >> "logs/latest" + echo "${OUTPUT_LOG}" >>"logs/latest" printf "\n Please see logs at: %s\n" "$(printGreen "${OUTPUT_DIR}")" +} + +get_checkpoint_dir() { + if [[ -n "${CKPT_DIR:-}" ]]; then + echo "${CKPT_DIR}" + else + echo "checkpoints/$(get_output_prefix)" + fi +} + +setup_checkpoint() { + ckpt_dir=$(get_checkpoint_dir) + export CKPT_DIR="${ckpt_dir}" printf "Checkpoints will be saved to: %s\n" "$(printYellow "${CKPT_DIR}")" } @@ -801,7 +898,7 @@ setOutput() { # Build DeepSpeed config and write to .json ############################################# buildDSconfig() { - export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}" + # export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}" export DS_CONFIG="${WORKING_DIR}/ds-configs/ds_stage${ZERO_STAGE}_mb${MICRO_BATCH}_gb${GLOBAL_BATCH}_pp${PP}_${DTYPE}.json" mkdir -p "$(dirname "${DS_CONFIG}")" echo "DS_CONFIG: ${DS_CONFIG}" @@ -809,7 +906,6 @@ buildDSconfig() { generateDSconfig "${DS_CONFIG}" } - ############################################################################### # sumWeights # @@ -844,7 +940,6 @@ make_data() { cd "${mdir}" && make && cd - } - ############################################################################## # install_dependencies # @@ -853,7 +948,7 @@ make_data() { install_dependencies() { 
depsfile="${WORKING_DIR}/ALCF/requirements/requirements.txt" echo "[install_dependencies] Ensuring all dependencies from ${depsfile} installed..." - python3 -m pip install -r "${depsfile}" --require-virtualenv 1> /dev/null + python3 -m pip install -r "${depsfile}" --require-virtualenv 1>/dev/null if [[ ! -x "$(command -v deepspeed)" ]]; then mn=$(get_machine_name) # if [[ "${mn}" == aurora* || "${mn}" == sunspot* ]]; then @@ -864,32 +959,6 @@ install_dependencies() { fi } -###################################################################### -# install_deepspeed_for_xpu -# -# Install microsoft/DeepSpeed on PVC -# -# This will: -# 1. Clone rep -# 2. Checkout appropriate branch -# 3. Install into virtual environment -###################################################################### -install_deepspeed_for_xpu() { - # python3 -m pip install "torch==2.1.0.post2" torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 oneccl_bind_pt==2.1.300+xpu --extra-index-url "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" - echo "Building + Installing DeepSpeed on $(hostname)" - outdir="${WORKING_DIR}/deps/DeepSpeed" - mkdir -p "${outdir}" - git clone https://github.com/microsoft/DeepSpeed.git "${outdir}" - cd "${outdir}" || exit - echo "[install_deepspeed_for_xpu] !! pwd: $(pwd)" - python3 -m pip install --require-virtualenv -r requirements/requirements.txt 1> /dev/null - python3 -m pip install xgboost "numpy<2" --force-reinstall --upgrade --require-virtualenv 1> /dev/null - python setup.py develop 1> /dev/null - cd "${WORKING_DIR}" - echo "[install_deepspeed_for_xpu] !! pwd: $(pwd)" -} - - ################################################# # Fix for distributed key value store on Aurora ################################################# @@ -946,7 +1015,6 @@ check_executable() { fi } - ###################################################################### # `makeHostiles`: # Detect if `HOSTFILE` set in active environment. @@ -955,7 +1023,7 @@ check_executable() { ###################################################################### makeHostfiles() { if [[ -n "${HOSTFILE}" ]]; then - printf "!! USING CUSTOM HOSTFILE FROM: %s" "${HOSTFILE}" + printf "!! 
USING CUSTOM HOSTFILE FROM: %s" "${HOSTFILE}" else make_ds_hostfile fi @@ -976,9 +1044,11 @@ setup_tokenizer_and_data() { fi echo "Setting up tokenizer with ${tok}" echo "Using data_file_list: ${dfl}" + _data_flags=() + _tokenizer_flags=() if [[ ${tok} == gpt* || ${tok} == GPT* ]]; then export TOKENIZER_TYPE="GPT2" - export TOKENIZER_FLAGS="--tokenizer-type GPT2BPETokenizer" + _tokenizer_flags+=("--tokenizer-type GPT2BPETokenizer") machine=$(get_machine_name) if [[ ${machine} == "polaris" ]]; then export DATA_PARENT="${DATA_PARENT:-/eagle/argonne_tpc/foremans/projects/argonne-lcf/Megatron-DeepSpeed/dataset}" @@ -992,18 +1062,25 @@ setup_tokenizer_and_data() { export VOCAB_FILE="${DATA_PARENT}/gpt2-vocab.json" export MERGE_FILE="${DATA_PARENT}/gpt2-merges.txt" export DATA_PATH="${DATA_PARENT}/BookCorpusDataset_text_document" - export DATA_FLAGS="--data-path ${DATA_PATH} --vocab-file ${VOCAB_FILE} --merge-file ${MERGE_FILE}" + _data_flags+=( + "--data-path ${DATA_PATH}" + "--vocab-file ${VOCAB_FILE}" + "--merge-file ${MERGE_FILE}" + ) else - export DATA_FLAGS="" - export TOKENIZER_TYPE="Llama2" - tm="${WORKING_DIR}/ALCF/tokenizer.model" # fallback: Megatron-DeepSpeed/ALCF/tokenizer.model - export TOKENIZER_MODEL="${TOKENIZER_MODEL:-${tm}}" # USE TOKENIZER_MODEL from env, else fallback from ^ - export TOKENIZER_FLAGS="--tokenizer-type Llama2Tokenizer --tokenizer-model ${TOKENIZER_MODEL}" - if [[ "${TOKENIZER_TYPE}" != "GPT2" ]]; then - echo "Using tokenizer: ${TOKENIZER_TYPE}. Setting up data with ${DATA_FILE_LIST-}" - setData "${dfl}" || exit - fi + export TOKENIZER_TYPE="${TOKENIZER_TYPE:-Llama2Tokenizer}" + tm="${WORKING_DIR}/ALCF/tokenizer.model" # fallback: Megatron-DeepSpeed/ALCF/tokenizer.model + export TOKENIZER_MODEL="${TOKENIZER_MODEL:-${tm}}" # USE TOKENIZER_MODEL from env, else fallback from ^ + _tokenizer_flags+=( + "--tokenizer-type ${TOKENIZER_TYPE}" + "--tokenizer-model ${TOKENIZER_MODEL}" + ) + # if [[ "${TOKENIZER_TYPE}" != "GPT2" ]]; then + echo "Using tokenizer: ${TOKENIZER_TYPE}. Setting up data with ${DATA_FILE_LIST:-}" + setData "${dfl}" || exit fi + export DATA_FLAGS="${_data_flags[*]}" + export TOKENIZER_FLAGS="${_tokenizer_flags[*]}" printf "[setData] DATA_FLAGS: %s\n" "$(printGreen "${DATA_FLAGS}")" printf "[setData] TOKENIZER_FLAGS: %s\n" "$(printMagenta "${TOKENIZER_FLAGS}")" } @@ -1014,7 +1091,7 @@ setup_tokenizer_and_data() { # Ensure `DATA_FILE_LIST` is set, # fallback to default values if necessary. ############################################### -setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] +setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] ####### [Set DATA_FILE_LIST_FALLBACK based on current machine] ############# mn=$(get_machine_name) dfl_fallback="${WORKING_DIR}/ALCF/data-lists/${mn}/dolma.txt" @@ -1023,7 +1100,7 @@ setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] # use this data file list to call `setData` dfl="${1:-${dfl_fallback}}" printf "Calling: setData() with %s\n" "${dfl}" - ndocs=$(wc -l < "${dfl}") + ndocs=$(wc -l <"${dfl}") ws=$(sumWeights "${dfl}") dfl_stem=$(echo "${dfl}" | tr "\/" "\t" | awk '{print $NF}' | sed "s/\.txt//g") dcp=".cache/${dfl_stem}/index-cache" @@ -1032,7 +1109,7 @@ setData() { # ------------------------[dfl: abbrv. 
for DATA_FILE_LIST] export WEIGHT_SUM="${ws}" export DFL_STEM="${dfl_stem}" export DATA_CACHE_PATH="${dcp}" - export DATA_FLAGS="${DATA_FLAGS} --data-file-list ${DATA_FILE_LIST}" # --data-cache-path ${DATA_CACHE_PATH}" + # export DATA_FLAGS="${DATA_FLAGS} --data-file-list ${DATA_FILE_LIST}" # --data-cache-path ${DATA_CACHE_PATH}" echo "--------------------" echo "Updated environment:" printf "DATA_FILE_LIST: %s\n" "${DATA_FILE_LIST}" @@ -1044,6 +1121,30 @@ setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] echo "--------------------" } +generateDSconfig_new() { + cat < "${CONFIG_JSON}" + { + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "bf16": { + "enabled": true + }, + + "data_types": { + "grad_accum_dtype": "fp32" + }, + + "wall_clock_breakdown" : false + } +EOT +} + ################################################################################ # generateDSconfig # @@ -1053,8 +1154,8 @@ setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] ################################################################################ generateDSconfig() { if [ $# -ne 1 ]; then - echo "Usage: $0 config_file" - exit 1 + echo "Usage: $0 config_file" + exit 1 fi for v in "$GLOBAL_BATCH" "$MICRO_BATCH" "$GRAD_ACC_STEPS" "$ZERO_STAGE" "$PP" "$DTYPE"; do if [ -z "$v" ]; then @@ -1062,16 +1163,6 @@ generateDSconfig() { exit 1 fi done - # \"optimizer\": { - # \"type\": \"AdamW\", - # \"params\": { - # \"lr\": ${LR}, - # \"beta1\": 0.9, - # \"beta2\": 0.95, - # \"eps\": 1e-5, - # \"weight_decay\": 1e-1 - # } - # }, # \"scheduler\": { # \"type\": \"WarmupLR\", # \"params\": { @@ -1086,13 +1177,17 @@ generateDSconfig() { \"train_micro_batch_size_per_gpu\": $MICRO_BATCH, \"steps_per_print\": 1, \"gradient_accumulation_steps\": $GRAD_ACC_STEPS, + \"zero_force_ds_cpu_optimizer\": false, \"zero_allow_untested_optimizer\": true, \"gradient_clipping\": 1.0, - \"activation_checkpointing\": { - \"partition_activations\": true, - \"contiguous_memory_optimization\": true - }, \"wall_clock_breakdown\": false," + if [[ "${USE_ACTIVATION_CHECKPOINTING}" == 1 ]]; then + activation_checkpointing="\ + \"activation_checkpointing\": { + \"partition_activations\": true, + \"contiguous_memory_optimization\": true + }," + fi flops_profiler="\ \"flops_profiler\": { \"enabled\": true, @@ -1133,8 +1228,22 @@ generateDSconfig() { else dtype="\"communication_data_type\": \"fp32\"," fi + if [[ "${OPT:-adamw}" == "ds.adamw" ]]; then + optimizer="\ + \"optimizer\": { + \"type\": \"AdamW\", + \"params\": { + \"lr\": ${LR}, + \"beta1\": 0.9, + \"beta2\": 0.95, + \"eps\": 1e-5, + \"weight_decay\": 1e-1 + }," + else + optimizer="" + fi if [[ "${ZERO_STAGE}" == 3 ]]; then - # \"mics_shard_size\": 2, + # \"mics_shard_size\": 2, zero="\ \"zero_optimization\": { \"stage\": 3, @@ -1157,9 +1266,8 @@ generateDSconfig() { } }," # elif [[ $ZERO_STAGE == 2 ]]; then - elif [[ "${ZERO_STAGE}" == 2 || "${ZERO_STAGE}" == 1 ]]; then - # if [[ -n "${CPU_OPTIMIZER}" ]]; then - if [[ "${CPU_OPTIMIZER:-0}" != 0 ]]; then + elif [[ "${ZERO_STAGE}" == 2 || "${ZERO_STAGE}" == 1 ]]; then + if [[ -z "${CPU_OPTIMIZER:-}" ]]; then echo "!!!! CAUGHT CPU_OPTIMIZER !!!!" 
zero="\ \"zero_optimization\": { @@ -1188,18 +1296,18 @@ generateDSconfig() { else extra="\ \"comms_logger\": { - \"enabled\": true, + \"enabled\": ${COMMS_LOGGER:-false}, \"verbose\": false, - \"prof_all\": true, \"debug\": false }," fi else echo 'Please add the correct config set!!!' fi -cat < "$1" + cat <"$1" { $common +$optimizer $zero $dtype $extra @@ -1243,7 +1351,7 @@ GREEN="\e[1;32m" YELLOW="\e[1;33m" BLUE="\e[1;34m" CYAN="\e[1;35m" -WHITE="\e[1;36m" +# WHITE="\e[1;36m" printBlack() { printf "\e[1;30m%s\e[0m\n" "$@" @@ -1277,6 +1385,73 @@ printWhite() { printf "\e[1;37m%s\e[0m\n" "$@" } +reset_env() { + custom_vars=( + NO_FLASH_ATTN + TP + PP + SP + FLASH_ARG + OPT + ADAM_BETA1 + ADAM_BETA2 + ADAM_EPS + WEIGHT_DECAY + HEADS + NLAYERS + HIDDEN + NUM_KV_HEAD + FFN_HIDDEN_SIZE + SEQ + ZERO_STAGE + MICRO_BATCH + EVAL_ITERS + EVAL_INTERVAL + TIMING_LOG_LEVEL + ACT_CKPT_NUM_LAYERS + USE_ACTIVATION_CHECKPOINTING + GLOBAL_BATCH_MAX + GLOBAL_BATCH + TRAIN_TOKENS + TRAIN_ITERS + MODEL_TYPE + LLAMA_ARGS + LR + LR_WARMUP_FRAC + LR_DECAY_ITERS + LR_ARGS + CPU_OPTIMIZER + DS_CONFIG + OUTPUT_DIR + OUTPUT_LOG + CKPT_DIR + ds_args + EXEC + EXEC_STEM + DATA_FLAGS + TOKENIZER_TYPE + TOKENIZER_MODEL + TOKENIZER_FLAGS + DATA_FILE_LIST + NUM_DOCS + WEIGHT_SUM + DFL_STEM + DATA_CACHE_PATH + DOTENV_FILE + YEAR + MONTH + DAY + TODAY + STARTED_AT + LAUNCHER + data_cache_path + DEFAULTS + ) + printf "Unsetting custom vars: %s\n" "${custom_vars[*]}" + unset "${custom_vars[@]}" +} + + ########################### # call helpers_main() ########################### diff --git a/ALCF/requirements/requirements.txt b/ALCF/requirements/requirements.txt index e0a969358d..03541ba514 100644 --- a/ALCF/requirements/requirements.txt +++ b/ALCF/requirements/requirements.txt @@ -15,3 +15,4 @@ six numpy<2 schedulefree packaging>=20.0 +wandb diff --git a/examples_deepspeed/finetune_hf_llama/ds_config.json b/examples_deepspeed/finetune_hf_llama/ds_config.json index 9c0b332473..85f439ce47 100755 --- a/examples_deepspeed/finetune_hf_llama/ds_config.json +++ b/examples_deepspeed/finetune_hf_llama/ds_config.json @@ -1,11 +1,5 @@ { "train_batch_size" : 256, "train_micro_batch_size_per_gpu": 16, - "steps_per_print": 100, - "zero_optimization": { - "stage": 0 - }, - "bf16": { - "enabled": true - } + "steps_per_print": 1 } diff --git a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh index c48ea11b93..ab8bfdf419 100644 --- a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh +++ b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh @@ -1,8 +1,8 @@ DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json -DATASET_PATH=./alpaca_data.json +DATASET_PATH=./examples_deepspeed/finetune_hf_llama/alpaca_data.json # dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json -HF_LLAMA_PATH=/data/llama-7b/ +HF_LLAMA_PATH=/data/llama-2-7b-hf/ # weights link: https://huggingface.co/huggyllama/llama-7b MICRO_BATCH_SIZE=16 @@ -44,11 +44,20 @@ cat < $DS_CONFIG EOT -covert_args="deepspeed tools/hf2megads_weight_converter.py \ +covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \ --hf-ckpt-num-shards 2 \ ---origin-hf-ckpt-dir $HF_LLAMA_PATH \ +--hf-ckpt-dir $HF_LLAMA_PATH \ +--load-mode auto \ --save $MEGA_DS_LLAMA_PATH" +covert_mds2hf_args="deepspeed tools/hf2megads_weight_converter.py \ +--hf-ckpt-num-shards 2 \ +--hf-ckpt-dir $HF_LLAMA_PATH \ +--load-mode auto \ +--to-hf-ckpt \ +--load $MEGA_DS_LLAMA_PATH \ +--save 
$HF_LLAMA_PATH'-hf-out' " + finetune_args="deepspeed finetune_llama.py \ --load $MEGA_DS_LLAMA_PATH" @@ -98,8 +107,10 @@ comm_args="--tensor-model-parallel-size $TP \ --no-gradient-accumulation-fusion \ --repeated-dataloader" -if [ "$1" = "convert" ]; then - task_args="$covert_args" +if [ "$1" = "convert_hf2mds" ]; then + task_args="$covert_hf2mds_args" +elif [ "$1" = "convert_mds2hf" ]; then + task_args="$covert_mds2hf_args" else task_args="$finetune_args" fi diff --git a/examples_deepspeed/pretrain_llama2_distributed.sh b/examples_deepspeed/pretrain_llama2_distributed.sh index f275ea636a..4c790e8c19 100755 --- a/examples_deepspeed/pretrain_llama2_distributed.sh +++ b/examples_deepspeed/pretrain_llama2_distributed.sh @@ -41,6 +41,17 @@ GRAD_CLIP=1 # activation_checkpoint="true" activation_checkpoint="false" +LOG_TO_WANDB=0 +WANDB_ARGS= +if [ $LOG_TO_WANDB -eq 1 ] +then +WANDB_ARGS="\ + --wandb-project pretrain-llama2 \ + --wandb-exp-name exp0 \ + --wandb-save-dir ${BASE_PATH}/wandb \ + " +fi + # Below configuration required for llama model as per llama paper # --no-query-key-layer-scaling \ # --attention-dropout 0 \ @@ -53,7 +64,6 @@ activation_checkpoint="false" ###################################### - cat < $DS_CONFIG { "train_batch_size" : $GLOBAL_BATCH_SIZE, @@ -132,4 +142,5 @@ torchrun $DISTRIBUTED_ARGS \ --normalization rmsnorm \ --disable-bias-linear \ --num-key-value-heads $NUM_KV_HEADS \ + $WANDB_ARGS \ $ds_args diff --git a/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh b/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh index da028dc731..24bfa544d6 100644 --- a/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh +++ b/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh @@ -187,14 +187,6 @@ host="${HOSTNAME}" seed=1234 num_workers=0 -data_path="BookCorpusDataset_text_document" -if [ ! -f "BookCorpusDataset_text_document.bin" ]; then - wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.bin -fi -if [ ! -f "BookCorpusDataset_text_document.idx" ]; then - wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.idx -fi - vocab_path="gpt2-vocab.json" if [ ! -f "$vocab_path" ]; then wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json @@ -204,6 +196,24 @@ if [ ! -f "$merge_path" ]; then wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt fi + +data_path="BookCorpusDataset_text_document" +if [ ! 
-f "BookCorpusDataset_text_document.bin" ]; then + # Download the Bookcorpus dataset and convert to json + python preprocess_bookcorpus.py + + # Process the dataset + python ${dir}/../../tools/preprocess_data.py \ + --input ${data_path}.json \ + --output-prefix "BookCorpusDataset" \ + --vocab-file $vocab_path \ + --merge-file $merge_path \ + --dataset-impl mmap \ + --tokenizer-type GPT2BPETokenizer \ + --workers 32 \ + --append-eod +fi + prescale_grad="true" jobname="gpt_${model_size}B_tok${train_tokens_in_billion}B" jobname="${jobname}_lr${lr}_min${min_lr}_w${lr_warmup_tokens_in_million}M_d${lr_decay_tokens_in_billion}B_${lr_decay_style}" diff --git a/examples_deepspeed/sequence_parallel/preprocess_bookcorpus.py b/examples_deepspeed/sequence_parallel/preprocess_bookcorpus.py new file mode 100644 index 0000000000..c35a13ea68 --- /dev/null +++ b/examples_deepspeed/sequence_parallel/preprocess_bookcorpus.py @@ -0,0 +1,4 @@ +from datasets import load_dataset + +train_data = load_dataset('bookcorpus/bookcorpus', split='train') +train_data.to_json("BookCorpusDataset_text_document.json", lines=True) diff --git a/examples_deepspeed/universal_checkpointing/README.md b/examples_deepspeed/universal_checkpointing/README.md index 341b0d113f..281d320e99 100644 --- a/examples_deepspeed/universal_checkpointing/README.md +++ b/examples_deepspeed/universal_checkpointing/README.md @@ -10,12 +10,12 @@ This folder contains example scripts that demonstrate how to use Universal Check For ZeRO stage 1, we provide bash scripts for bf16 and fp16 training examples corresponding to the steps 1 and 3 above. The step 1 scripts launch a training run of TP=PP=DP=2 of 200 iterations that creates a checkpoint every 100 iterations. The step 3 scripts load a universal checkpoint of iteration 100 and resume training with TP=PP=2 and DP=1 for an additional 100 iterations. Users can modify these scripts to try out other save and resume 3D combinations (e.g., save TP=PP=DP=1 and resume TP=PP=DP=2). Tensorboard logs are created by both step 1 and 3 scripts to enable visual inspection of how well the loss curves of the initial and resumed training runs match, especially at iteration 101. 1. bf16: - * run_bf16.sh: step 1 - * run_universal_bf16.sh: step 3 + * megatron_gpt/run_bf16.sh: step 1 + * megatron_gpt/run_universal_bf16.sh: step 3 2. fp16: - * run_fp16.sh: step 1 - * run_universal_fp16.sh: step 3 + * megatron_gpt/run_fp16.sh: step 1 + * megatron_gpt/run_universal_fp16.sh: step 3 Please note that these scripts should be run from the root folder of the repo (i.e., two levels above this README). For illustration, here are the commands for running the bf16 example. @@ -41,22 +41,22 @@ NOTE: Make sure to update your `BASE_DATA_PATH` path in the `run_[bf16/fp16].sh` ### Step 1: Create ZeRO checkpoint ```bash - bash examples_deepspeed/universal_checkpointing/run_bf16.sh + bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh ``` -By default the script will create the checkpoints in folder `z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy` +By default the script will create the checkpoints in folder `z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy` ### Step 2: Convert ZeRO checkpoint of iteration 100 to Universal format Assuming the DeepSpeed source code is cloned into the home folder, the following command will generate universal checkpoint for iteration 100. 
```bash python ${HOME}/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py \ - --input_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/global_step100 \ - --output_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/global_step100_universal + --input_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/global_step100 \ + --output_folder z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/global_step100_universal ``` Note that we chose to create the universal checkpoint in the same checkpoint folder as the ZeRO checkpoint. This maintains the normal checkpoint folder structure expected by the Megatron-DeepSpeed code, which makes it easy to load universal checkpoints with little/no script or code changes. For clarity, we show below the contents of the checkpoint folder after creation of the universal checkpoint. Note that the conversion script creates `global_step100_universal` folder and `latest_universal` file. ```bash -ls -l z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_toy/ +ls -l z1_uni_ckpt/checkpoints/gpt2/z1/bf16/tp2_pp2_dp2_sp1_toy/ total 48 drwxr-xr-x 2 user group 4096 Oct 21 08:51 global_step100 drwxr-xr-x 3 user group 4096 Oct 21 09:28 global_step100_universal @@ -69,7 +69,7 @@ drwxr-xr-x 2 user group 4096 Oct 21 09:01 global_step200 ### Step 3: Resume training with Universal checkpoint of iteration 100 ```bash -bash examples_deepspeed/universal_checkpointing/run_universal_bf16.sh +bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh ``` This resumption script effects the loading of universal checkpoint rather than the ZeRO checkpoint in the folder by passing `--universal-checkpoint` command line flag to the main training script (i.e., `pretrain_gpt.py`). @@ -77,13 +77,15 @@ Please see the corresponding [pull request](https://github.com/microsoft/Megatro Combining sequence parallelism with data parallelism is another good use case for universal checkpointing, see [sp pull request](https://github.com/microsoft/DeepSpeed/pull/4752) for example and visualization of matching loss values. +Notes: The model weights using the ```--no-pipeline-parallel``` parameter and the model weights not using the ```--no-pipeline-parallel``` parameter are currently not supported for mutual conversion. + ### TensorBoard Log Analysis The Universal Checkpointing example includes a TensorBoard analysis script that will generate `csv` files and `png` plots across the unviersal checkpointing training steps for comparison of training and validation loss curves. After Step 3 is completed, the script may be executed as follows: ```bash -bash examples_deepspeed/universal_checkpointing/run_tb_analysis.sh z1_uni_ckpt +bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt.sh z1_uni_ckpt ``` The script will output the following `csv` files: @@ -116,4 +118,25 @@ Repeat steps in ZeRO stage 1 training above with the following modifications to * Set ZERO_STAGE=2 * Add `--no-pipeline-parallel` flag to deepspeed options -## ZeRO stage 3 training (**Coming soon**) +## ZeRO stage 3 training +Repeat steps in ZeRO stage 1 training above with the following modifications to your job batch scripts: +* Set ZERO_STAGE=3 +* Add `--no-pipeline-parallel` flag to deepspeed options + +> **Note:** that the stage 3 universal checkpoint currently supports Data parallelism. + +Below is the visualization of the `png` files generated from ZeRO stage 3. + +
+![Training LM loss for ZeRO stage 3 universal checkpointing](assets/image/uc_stage3_char_training_loss.png)
+
+*Figure 1: Training LM loss curve for first 200 training steps of Step 1 (TP=1, PP=1, DP=4) and training steps 101 to 200 of Step 3 (TP=1, PP=1, DP=2), which was loaded using the Universal Checkpoint.*
+
+![Validation LM loss for ZeRO stage 3 universal checkpointing](assets/image/uc_stage3_char_validation_loss.png)
+
+*Figure 2: Validation LM loss curve for first 200 training steps of Step 1 (TP=1, PP=1, DP=4) and training steps 101 to 200 of Step 3 (TP=1, PP=1, DP=2), which was loaded using the Universal Checkpoint.*
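+
+For concreteness, a sketch of the stage-3 walk-through under the same folder conventions as the ZeRO stage 1 example above (the `z1_uni_ckpt` prefix becomes `z3_uni_ckpt` when `ZERO_STAGE=3`); the checkpoint sub-folder name (shown here as `tp1_pp1_dp4_sp1_toy`, matching the figure captions) depends on the parallelism settings in your run script:
+
+```bash
+# Step 1: create the ZeRO stage 3 checkpoint, after setting ZERO_STAGE=3
+# (and adding --no-pipeline-parallel) in the step 1 / step 3 scripts
+bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh
+
+# Step 2: convert the iteration-100 ZeRO stage 3 checkpoint to universal format
+python ${HOME}/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py \
+    --input_folder z3_uni_ckpt/checkpoints/gpt2/z3/bf16/tp1_pp1_dp4_sp1_toy/global_step100 \
+    --output_folder z3_uni_ckpt/checkpoints/gpt2/z3/bf16/tp1_pp1_dp4_sp1_toy/global_step100_universal
+
+# Step 3: resume training from the universal checkpoint
+bash examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh
+```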
+ + diff --git a/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_training_loss.png b/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_training_loss.png new file mode 100644 index 0000000000..4c6758e991 Binary files /dev/null and b/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_training_loss.png differ diff --git a/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_validation_loss.png b/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_validation_loss.png new file mode 100644 index 0000000000..30d6f72eb8 Binary files /dev/null and b/examples_deepspeed/universal_checkpointing/assets/image/uc_stage3_char_validation_loss.png differ diff --git a/examples_deepspeed/universal_checkpointing/llama/run_llama_bf16.sh b/examples_deepspeed/universal_checkpointing/llama/run_llama_bf16.sh new file mode 100644 index 0000000000..72e79d4f1f --- /dev/null +++ b/examples_deepspeed/universal_checkpointing/llama/run_llama_bf16.sh @@ -0,0 +1,175 @@ +#!/bin/bash +set -ex + +DIR=`pwd` +###################################### +# Change the below configurations here +BASE_PATH=dataset +DS_CONFIG=${BASE_PATH}/deepspeed.json +DATASET=${BASE_PATH}/my-gpt2_text_document +TOKENIZER_PATH=${BASE_PATH}/llama-7b/tokenizer.model # offical llama tokenizer.model + +GPUS_PER_NODE=8 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 + +HIDDEN_SIZE=2048 # e.g. llama-13b: 5120 +FFN_HIDDEN_SIZE=5504 # e.g. llama-13b: 13824 +NUM_LAYERS=24 # e.g. llama-13b: 40 +NUM_HEADS=16 # e.g. llama-13b: 40 +SEQ=2048 + +LR_WARMUP_STEPS=2000 +WEIGHT_DECAY=0.1 +GRAD_CLIP=1 + +## Activation checkpointing saves GPU memory, but reduces training speed +# activation_checkpoint="true" +activation_checkpoint="false" + +ZERO_STAGE=1 +DTYPE="bf16" + +# 3D parallelism of training +TP=2 +PP=2 +DP=2 +SP=1 +WORLD_SIZE=$((TP*PP*DP*SP)) +GLOBAL_BATCH=32 +MICRO_BATCH=$((GLOBAL_BATCH/WORLD_SIZE)) +TRAIN_ITERS=250000 +LR=3e-4 +MIN_LR=3e-5 + +# Debug +DEBUG_MODE=1 +if [[ $DEBUG_MODE == 1 ]]; then + EXIT_INTERVAL=200 + SIZE_TAG="toy" +else + EXIT_INTERVAL=$TRAIN_ITERS + SIZE_TAG="big" +fi + +# 3D parallelism of checkpoint to load +LOAD_TP=$TP +LOAD_PP=$PP +LOAD_DP=$DP +LOAD_SP=$SP +RUN_TAG="save" + + +EXP_DIR="z${ZERO_STAGE}_uni_ckpt" +CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_${SIZE_TAG} +LOAD_CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${LOAD_TP}_pp${LOAD_PP}_dp${LOAD_DP}_sp${LOAD_SP}_${SIZE_TAG} +LOG_DIR="${EXP_DIR}/tensorboard/llama/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_${SIZE_TAG}_${RUN_TAG}" +mkdir -p $LOG_DIR + +# Below configuration required for llama model as per llama paper +# --no-query-key-layer-scaling \ +# --attention-dropout 0 \ +# --hidden-dropout 0 \ +# --use-rotary-position-embeddings \ +# --untie-embeddings-and-output-weights \ +# --swiglu \ +# --normalization rmsnorm \ +# --disable-bias-linear \ +###################################### + +cat < $DS_CONFIG +{ + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "bf16": { + "enabled": true + }, + + "wall_clock_breakdown" : false +} +EOT + +ds_args="" +ds_args=" --deepspeed ${ds_args}" +ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" +ds_args=" --zero-stage=$ZERO_STAGE 
${ds_args}" + +if [ "${activation_checkpoint}" = "true" ]; then + ds_args="--deepspeed-activation-checkpointing ${ds_args}" + + ## old argument for recomputing the transformer layer + # ds_args="--checkpoint-activations ${ds_args}" + + ## new argument for recomputing the transformer layer + ds_args="--recompute-granularity full --recompute-method uniform ${ds_args}" + ## new argument for recomputing only the attention layer + # ds_args="--recompute-granularity selective ${ds_args}" +fi + +if [[ ${ZERO_STAGE} -gt 1 ]]; then +ds_args="${ds_args} \ + --no-pipeline-parallel" +fi + +options="\ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --ds-sequence-parallel-size $SP \ + --num-layers $NUM_LAYERS \ + --hidden-size $HIDDEN_SIZE \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --num-attention-heads $NUM_HEADS \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --train-iters $TRAIN_ITERS \ + --save ${CHECKPOINT_PATH} \ + --load ${LOAD_CHECKPOINT_PATH} \ + --data-path $DATASET \ + --data-impl mmap \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr $LR \ + --lr-decay-style cosine \ + --min-lr $MIN_LR \ + --weight-decay $WEIGHT_DECAY \ + --clip-grad $GRAD_CLIP \ + --lr-warmup-iters $LR_WARMUP_STEPS \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --log-interval 1 \ + --save-interval 100 \ + --eval-interval 10 \ + --eval-iters 40 \ + --exit-interval ${EXIT_INTERVAL} \ + --${DTYPE} \ + --no-query-key-layer-scaling \ + --attention-dropout 0 \ + --hidden-dropout 0 \ + --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights \ + --swiglu \ + --normalization rmsnorm \ + --disable-bias-linear \ + --tensorboard-dir $LOG_DIR \ + $ds_args +" + +WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE" +run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}" + +echo ${options} +echo ${run_cmd} +eval ${run_cmd} diff --git a/examples_deepspeed/universal_checkpointing/llama/run_tb_analysis_llama.sh b/examples_deepspeed/universal_checkpointing/llama/run_tb_analysis_llama.sh new file mode 100755 index 0000000000..b807fb97a7 --- /dev/null +++ b/examples_deepspeed/universal_checkpointing/llama/run_tb_analysis_llama.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +OUTPUT_PATH=$1 + +if [ "$OUTPUT_PATH" == "" ]; then + OUTPUT_PATH="z1_uni_ckpt" +fi + +# Training Loss +python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ + --tb_dir $OUTPUT_PATH \ + --tb_event_key "lm-loss-training/lm loss" \ + --plot_name "uc_char_training_loss.png" \ + --plot_title "Llama 7B Universal Checkpointing - Training Loss" \ + +# Validation Loss +python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ + --tb_dir $OUTPUT_PATH \ + --tb_event_key "lm-loss-validation/lm loss validation" \ + --csv_name "val_" \ + --plot_name "uc_char_validation_loss.png" \ + --plot_title "Llama 7B Universal Checkpointing - Validation Loss" \ + --plot_y_label "Validation LM Loss" \ diff --git a/examples_deepspeed/universal_checkpointing/llama/run_universal_llama_bf16.sh b/examples_deepspeed/universal_checkpointing/llama/run_universal_llama_bf16.sh new file mode 100644 index 0000000000..334fa3eaf6 --- /dev/null +++ b/examples_deepspeed/universal_checkpointing/llama/run_universal_llama_bf16.sh @@ -0,0 +1,176 @@ +#!/bin/bash +set -ex + +DIR=`pwd` +###################################### +# Change the below configurations here +BASE_PATH=dataset +DS_CONFIG=${BASE_PATH}/deepspeed.json +DATASET=${BASE_PATH}/my-gpt2_text_document +TOKENIZER_PATH=${BASE_PATH}/llama-7b/tokenizer.model # offical llama tokenizer.model + +GPUS_PER_NODE=8 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 + +HIDDEN_SIZE=2048 # e.g. llama-13b: 5120 +FFN_HIDDEN_SIZE=5504 # e.g. llama-13b: 13824 +NUM_LAYERS=24 # e.g. llama-13b: 40 +NUM_HEADS=16 # e.g. llama-13b: 40 +SEQ=2048 + +LR_WARMUP_STEPS=2000 +WEIGHT_DECAY=0.1 +GRAD_CLIP=1 + +## Activation checkpointing saves GPU memory, but reduces training speed +# activation_checkpoint="true" +activation_checkpoint="false" + +ZERO_STAGE=1 +DTYPE="bf16" + +# 3D parallelism of training +TP=2 +PP=2 +DP=1 +SP=1 +WORLD_SIZE=$((TP*PP*DP*SP)) +GLOBAL_BATCH=32 +MICRO_BATCH=$((GLOBAL_BATCH/WORLD_SIZE)) +TRAIN_ITERS=250000 +LR=3e-4 +MIN_LR=3e-5 + +# Debug +DEBUG_MODE=1 +if [[ $DEBUG_MODE == 1 ]]; then + EXIT_INTERVAL=200 + SIZE_TAG="toy" +else + EXIT_INTERVAL=$TRAIN_ITERS + SIZE_TAG="big" +fi + +# 3D parallelism of checkpoint to load +LOAD_TP=2 +LOAD_PP=2 +LOAD_DP=2 +LOAD_SP=1 +RUN_TAG="uni_load${LOAD_TP}_${LOAD_PP}_${LOAD_DP}_${LOAD_SP}" + + +EXP_DIR="z${ZERO_STAGE}_uni_ckpt" +CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_${SIZE_TAG} +LOAD_CHECKPOINT_PATH=${EXP_DIR}/checkpoints/llama/z${ZERO_STAGE}/$DTYPE/tp${LOAD_TP}_pp${LOAD_PP}_dp${LOAD_DP}_sp${LOAD_SP}_${SIZE_TAG} +LOG_DIR="${EXP_DIR}/tensorboard/llama/$DTYPE/tp${TP}_pp${PP}_dp${DP}_sp${SP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_${SIZE_TAG}_${RUN_TAG}" +mkdir -p $LOG_DIR + +# Below configuration required for llama model as per llama paper +# --no-query-key-layer-scaling \ +# --attention-dropout 0 \ +# --hidden-dropout 0 \ +# --use-rotary-position-embeddings \ +# --untie-embeddings-and-output-weights \ +# --swiglu \ +# --normalization rmsnorm \ +# --disable-bias-linear \ +###################################### + +cat < $DS_CONFIG +{ + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "bf16": { + "enabled": true + }, + + "wall_clock_breakdown" : false +} +EOT + +ds_args="" 
+ds_args=" --deepspeed ${ds_args}" +ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" +ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}" + +if [ "${activation_checkpoint}" = "true" ]; then + ds_args="--deepspeed-activation-checkpointing ${ds_args}" + + ## old argument for recomputing the transformer layer + # ds_args="--checkpoint-activations ${ds_args}" + + ## new argument for recomputing the transformer layer + ds_args="--recompute-granularity full --recompute-method uniform ${ds_args}" + ## new argument for recomputing only the attention layer + # ds_args="--recompute-granularity selective ${ds_args}" +fi + +if [[ ${ZERO_STAGE} -gt 1 ]]; then +ds_args="${ds_args} \ + --no-pipeline-parallel" +fi + +options="\ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --ds-sequence-parallel-size $SP \ + --num-layers $NUM_LAYERS \ + --hidden-size $HIDDEN_SIZE \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --num-attention-heads $NUM_HEADS \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --train-iters $TRAIN_ITERS \ + --save ${CHECKPOINT_PATH} \ + --load ${LOAD_CHECKPOINT_PATH} \ + --data-path $DATASET \ + --data-impl mmap \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr $LR \ + --lr-decay-style cosine \ + --min-lr $MIN_LR \ + --weight-decay $WEIGHT_DECAY \ + --clip-grad $GRAD_CLIP \ + --lr-warmup-iters $LR_WARMUP_STEPS \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --log-interval 1 \ + --save-interval 100 \ + --eval-interval 10 \ + --eval-iters 40 \ + --exit-interval ${EXIT_INTERVAL} \ + --${DTYPE} \ + --no-query-key-layer-scaling \ + --attention-dropout 0 \ + --hidden-dropout 0 \ + --use-rotary-position-embeddings \ + --untie-embeddings-and-output-weights \ + --swiglu \ + --normalization rmsnorm \ + --disable-bias-linear \ + --tensorboard-dir $LOG_DIR \ + --universal-checkpoint \ + $ds_args +" + +WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE" +run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}" + +echo ${options} +echo ${run_cmd} +eval ${run_cmd} diff --git a/examples_deepspeed/universal_checkpointing/run_bf16.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh similarity index 99% rename from examples_deepspeed/universal_checkpointing/run_bf16.sh rename to examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh index 0953954222..07cbc30e72 100755 --- a/examples_deepspeed/universal_checkpointing/run_bf16.sh +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_bf16.sh @@ -3,7 +3,7 @@ DIR=`pwd` DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -BASE_DATA_PATH=datasets +BASE_DATA_PATH=dataset DATASET=${BASE_DATA_PATH}/my-gpt2_text_document VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt diff --git a/examples_deepspeed/universal_checkpointing/run_fp16.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_fp16.sh similarity index 99% rename from examples_deepspeed/universal_checkpointing/run_fp16.sh rename to examples_deepspeed/universal_checkpointing/megatron_gpt/run_fp16.sh index 691fa8a8e6..2f1b994079 100755 --- a/examples_deepspeed/universal_checkpointing/run_fp16.sh +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_fp16.sh @@ -3,7 +3,7 @@ DIR=`pwd` DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -BASE_DATA_PATH=datasets 
+BASE_DATA_PATH=dataset DATASET=${BASE_DATA_PATH}/my-gpt2_text_document VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt diff --git a/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt.sh similarity index 96% rename from examples_deepspeed/universal_checkpointing/run_tb_analysis.sh rename to examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt.sh index 7aa988a0a0..3a17d66750 100755 --- a/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt.sh @@ -16,7 +16,6 @@ python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_scrip --tb_event_key "lm-loss-training/lm loss" \ --plot_name "uc_char_training_loss.png" \ --plot_title "Megatron-GPT Universal Checkpointing - Training Loss" \ - --use_sns # Validation Loss python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ @@ -26,4 +25,3 @@ python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_scrip --plot_name "uc_char_validation_loss.png" \ --plot_title "Megatron-GPT Universal Checkpointing - Validation Loss" \ --plot_y_label "Validation LM Loss" \ - --use_sns diff --git a/examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt_plot_only.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt_plot_only.sh new file mode 100755 index 0000000000..0c3ea5399c --- /dev/null +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_tb_analysis_gpt_plot_only.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +OUTPUT_PATH=$1 + +if [ "$OUTPUT_PATH" == "" ]; then + OUTPUT_PATH="z1_uni_ckpt" +fi + +# Training Loss +python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ + --tb_dir $OUTPUT_PATH \ + --tb_event_key "lm-loss-training/lm loss" \ + --plot_name "uc_char_training_loss.png" \ + --plot_title "Megatron-GPT Universal Checkpointing - Training Loss" \ + --plot_only \ + --csv_dir "/workspace/uc/megatron/loss_csv" \ + +# Validation Loss +python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ + --tb_dir $OUTPUT_PATH \ + --tb_event_key "lm-loss-validation/lm loss validation" \ + --csv_name "val_" \ + --plot_name "uc_char_validation_loss.png" \ + --plot_title "Megatron-GPT Universal Checkpointing - Validation Loss" \ + --plot_y_label "Validation LM Loss" \ + --plot_only \ + --csv_dir "/workspace/uc/megatron/val_csv" \ diff --git a/examples_deepspeed/universal_checkpointing/run_universal_bf16.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh similarity index 99% rename from examples_deepspeed/universal_checkpointing/run_universal_bf16.sh rename to examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh index ef0e134cfc..4134b9df48 100755 --- a/examples_deepspeed/universal_checkpointing/run_universal_bf16.sh +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_bf16.sh @@ -3,7 +3,7 @@ DIR=`pwd` DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -BASE_DATA_PATH=datasets +BASE_DATA_PATH=dataset DATASET=${BASE_DATA_PATH}/my-gpt2_text_document VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt diff --git 
a/examples_deepspeed/universal_checkpointing/run_universal_fp16.sh b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_fp16.sh similarity index 99% rename from examples_deepspeed/universal_checkpointing/run_universal_fp16.sh rename to examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_fp16.sh index 1e207e422b..bb3a538951 100755 --- a/examples_deepspeed/universal_checkpointing/run_universal_fp16.sh +++ b/examples_deepspeed/universal_checkpointing/megatron_gpt/run_universal_fp16.sh @@ -3,7 +3,7 @@ DIR=`pwd` DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -BASE_DATA_PATH=datasets +BASE_DATA_PATH=dataset DATASET=${BASE_DATA_PATH}/my-gpt2_text_document VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt diff --git a/examples_deepspeed/universal_checkpointing/tb_analysis/arguments.py b/examples_deepspeed/universal_checkpointing/tb_analysis/arguments.py index 3dacb45d4e..ca80872ca0 100644 --- a/examples_deepspeed/universal_checkpointing/tb_analysis/arguments.py +++ b/examples_deepspeed/universal_checkpointing/tb_analysis/arguments.py @@ -17,3 +17,5 @@ parser.add_argument("--skip_csv", action='store_true', help="Skip generation of csv files") parser.add_argument("--use_sns", action='store_true', help="Use the SNS library to format plot") parser.add_argument("--csv_name", required=False, default="", type=str, help="Unique name for CSV files") +parser.add_argument("--plot_only", action='store_true', help="Plot only using csv files") +parser.add_argument("--csv_dir", required=False, type=str, help="Directory for csv files") diff --git a/examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py b/examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py index 337f6540ab..fbf9b6dd28 100644 --- a/examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py +++ b/examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py @@ -6,9 +6,10 @@ import os import re import pandas as pd +import csv import matplotlib.pyplot as plt from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from utils import get_analyzer, find_files +from utils import get_analyzer, find_files_prefix, find_files_suffix from arguments import parser args = parser.parse_args() @@ -18,8 +19,8 @@ sns.set() def main(): - target_affix = 'events.out.tfevents' - tb_log_paths = find_files(args.tb_dir, target_affix) + target_prefix = 'events.out.tfevents' + tb_log_paths = find_files_prefix(args.tb_dir, target_prefix) analyzer = get_analyzer(args.analyzer) @@ -41,6 +42,8 @@ def main(): df = pd.DataFrame({"step": x, "value": y}) df.to_csv(f"{args.csv_name}{analyzer.get_csv_filename()}.csv") + plt.grid(True) + if not args.skip_plot: plt.legend() plt.title(args.plot_title) @@ -48,5 +51,35 @@ def main(): plt.ylabel(args.plot_y_label) plt.savefig(args.plot_name) +def plot_csv(): + target_suffix = 'csv' + csv_log_files = find_files_suffix(args.csv_dir, target_suffix) + + analyzer = get_analyzer(args.analyzer) + + for csv_file in csv_log_files: + analyzer.set_names(csv_file) + + x, y = [], [] + with open(csv_file, 'r') as file: + reader = csv.reader(file) + for row in reader: + if row[1] == 'step': + continue + x.append(int(row[1])) # Assuming the first column contains x values + y.append(float(row[2])) # Assuming the second column contains y values + + plt.plot(x, y, label=f'{analyzer.get_label_name()}') + + plt.grid(True) + plt.legend() + plt.title(args.plot_title) + 
plt.xlabel(args.plot_x_label) + plt.ylabel(args.plot_y_label) + plt.savefig(args.plot_name) + if __name__ == "__main__": - main() + if args.plot_only: + plot_csv() + else: + main() diff --git a/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py b/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py index f5809c3dc1..20d46ff6a8 100644 --- a/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py +++ b/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py @@ -19,7 +19,7 @@ def set_names(self, path_name): tp, pp, dp, sp = match.groups() self._label_name = f"Training Run: TP: {tp}, PP: {pp}, DP: {dp}" - self._csv_name = f"uc_out_tp_{tp}_pp_{pp}_dp_{dp}_sp_{sp}" + self._csv_name = f"uc_out_tp{tp}_pp{pp}_dp{dp}_sp{sp}" def get_label_name(self): return self._label_name diff --git a/examples_deepspeed/universal_checkpointing/tb_analysis/utils.py b/examples_deepspeed/universal_checkpointing/tb_analysis/utils.py index 4bbbb3f2f0..db6624bbc4 100644 --- a/examples_deepspeed/universal_checkpointing/tb_analysis/utils.py +++ b/examples_deepspeed/universal_checkpointing/tb_analysis/utils.py @@ -7,13 +7,13 @@ from uc_analysis import UniversalCheckpointingAnalysis -def find_files(directory, file_affix): +def find_files_prefix(directory, file_prefix): """ - Searches for files with a specific affix in a directory using os.walk(). + Searches for files with a specific prefix in a directory using os.walk(). Args: directory (str): The path to the directory to search. - file_affix (str): The desired file affix. + file_prefix (str): The desired file prefix. Returns: list: A list of paths to matching files. @@ -21,10 +21,28 @@ def find_files(directory, file_affix): matching_paths = [] for root, _, files in os.walk(directory): for filename in files: - if root not in matching_paths and filename.lower().startswith(file_affix.lower()): + if root not in matching_paths and filename.lower().startswith(file_prefix.lower()): matching_paths.append(os.path.join(root)) return matching_paths +def find_files_suffix(directory, file_suffix): + """ + Searches for files with a specific suffix in a directory using os.walk(). + + Args: + directory (str): The path to the directory to search. + file_suffix (str): The desired file suffix. + + Returns: + list: A list of paths to matching files. + """ + matching_paths = [] + for root, _, files in os.walk(directory): + for filename in files: + if root not in matching_paths and filename.lower().endswith(file_suffix.lower()): + matching_paths.append(os.path.join(filename)) + return matching_paths + def get_analyzer(analyzer_name): if analyzer_name == 'universal_checkpointing': return UniversalCheckpointingAnalysis() diff --git a/examples_deepspeed/zero_bubble_pp/README.md b/examples_deepspeed/zero_bubble_pp/README.md new file mode 100644 index 0000000000..2290860783 --- /dev/null +++ b/examples_deepspeed/zero_bubble_pp/README.md @@ -0,0 +1,51 @@ +# Zero Bubble Pipeline Parallelism Tutorials + +This folder contains examples and tutorials to enable Zero Bubble Pipeline Parallelism ([Paper Link](https://arxiv.org/abs/2401.10241)). The key idea is to break a backward pass into a $B$ pass and a $W$ pass. $B$ on one stage only depends on the $B$ of the next stage, whereas in 1F1B it depends on both $B$ and $W$. + +![BW Split](./bw_split.png) + +Currently supported zero bubble schedules: +* ZB-H1 + +## ZB-H1 + +![alt text](zbh1.png) + +As shown in the above image, the ZB-H1 schedule cuts the pipeline bubble of 1F1B to 1/3.
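For intuition, the sketch below (a hand-written illustration with assumed toy shapes, not the repository's implementation; the real mechanism is the `WeightGradStore` added later in this patch) shows what the B/W split means for a single linear layer: the input gradient (B) can be produced and sent upstream immediately, while the weight gradient (W) can be deferred to fill the pipeline bubble.

```python
import torch

# Assumed toy shapes; x plays the role of the activation from the previous stage.
x = torch.randn(4, 8)
w = torch.randn(16, 8)        # weight of y = x @ w.T
grad_y = torch.randn(4, 16)   # gradient arriving from the next stage

# B pass: compute only the input gradient and hand it to the previous stage.
grad_x = grad_y @ w           # dL/dx, needed immediately by the upstream stage

# W pass: the weight gradient does not feed any other stage, so it can be
# stashed and executed later (this is what WeightGradStore.put/pop do).
grad_w = grad_y.t() @ x       # dL/dw, accumulated into the weight's .grad later
```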
+ +### ZB-H1 and Its Variation +There are two versions of ZB-H1 implemented in Megatron-DeepSpeed: an official version (the 2nd schedule in the above image) which does a uniform B-W split, and another variation (the 3rd schedule in the image) that does the B-W split only when necessary. We provide the variation as the default implementation. + +In practice the variation is more friendly to a synchronized communication implementation and to combined usage with tensor parallelism. However, it changes the order in which the weight updates of different microbatches are applied (e.g. for Device 4 in the image above, the weight-update order is 4->5->6->7->1->2->3->8), and hence might result in a slightly different loss curve. + + +### How to use + +Simply add the following flag to the options to enable ZB-H1: + +``` +--enable-zbh1-pipeline +``` +The default implementation is the variation of ZB-H1 mentioned in the [Previous Section](#zb-h1). + +If you want bit-to-bit exact semantics compared to 1F1B, you can use the following flag. It might be a bit slower than the default implementation. + +``` +--enable-zbh1-exact-semantics +``` + +### ZB-H1 Toy Example + +Here is a toy example for using **ZB-H1** inside the DeepSpeed repo. + +First, prepare some sample training data and change the `data_path` in `zbh1_pretrain_gpt_1.3b.sh`. Then, under this folder, run + +``` +bash zbh1_pretrain_gpt_1.3b.sh +``` + +## Benchmarks + +The implementation has been checked and verified on various setups such as ZeRO Stage 1, activation recomputation, flash attention, tensor parallelism, data parallelism and bf16. By approximate measure, ~10% acceleration was observed when the microbatch count is twice the number of pipeline stages: + +![alt text](benchmark.png) \ No newline at end of file diff --git a/examples_deepspeed/zero_bubble_pp/benchmark.png b/examples_deepspeed/zero_bubble_pp/benchmark.png new file mode 100644 index 0000000000..be46817d75 Binary files /dev/null and b/examples_deepspeed/zero_bubble_pp/benchmark.png differ diff --git a/examples_deepspeed/zero_bubble_pp/bw_split.png b/examples_deepspeed/zero_bubble_pp/bw_split.png new file mode 100644 index 0000000000..1ced957b44 Binary files /dev/null and b/examples_deepspeed/zero_bubble_pp/bw_split.png differ diff --git a/examples_deepspeed/zero_bubble_pp/zbh1.png b/examples_deepspeed/zero_bubble_pp/zbh1.png new file mode 100644 index 0000000000..364ef368a3 Binary files /dev/null and b/examples_deepspeed/zero_bubble_pp/zbh1.png differ diff --git a/examples_deepspeed/zero_bubble_pp/zbh1_pretrain_gpt_1.3b.sh b/examples_deepspeed/zero_bubble_pp/zbh1_pretrain_gpt_1.3b.sh new file mode 100644 index 0000000000..cf5705d973 --- /dev/null +++ b/examples_deepspeed/zero_bubble_pp/zbh1_pretrain_gpt_1.3b.sh @@ -0,0 +1,367 @@ +#!/bin/bash +dir=`pwd` +############################################################################### +### Main configs +## GPT-3 models use 2K sequence length/context window +seq_len=2048 + + +## The "GPT-3 XXX" configs below are from the GPT-3 paper +## https://arxiv.org/abs/2005.14165; choose based on +## your desired model size or build your own configs + + +## init_std is the standard deviation for weight initialization. Usually a larger +## model needs a lower std. We used a heuristic equation of sqrt(1/3/hidden_size) +## from the MT-NLG 530B work (https://arxiv.org/pdf/2201.11990.pdf) + + +## We changed min_lr to a lower number (1.0e-6), which we found provides +## better zero-shot eval results.
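As a quick illustration of the `init_std` heuristic mentioned in the comments above (a standalone sanity check, not part of the script), the hard-coded values in the config blocks below follow directly from `sqrt(1/3/hidden_size)`:

```python
import math

# init_std ~= sqrt(1 / 3 / hidden_size); compare with the per-size configs below.
for hidden_size in (768, 2048, 4096, 12288):
    print(hidden_size, round(math.sqrt(1 / 3 / hidden_size), 3))
# 768   -> 0.021 (config uses 0.02)
# 2048  -> 0.013
# 4096  -> 0.009
# 12288 -> 0.005
```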
+ + +## GPT-3 Small 125M +# model_size=0.125 +# num_layers=12 +# hidden_size=768 +# num_attn_heads=12 +# global_batch_size=256 +# lr=6.0e-4 +# min_lr=1.0e-6 +# init_std=0.02 + + +## GPT-3 Medium 350M +# model_size=0.35 +# num_layers=24 +# hidden_size=1024 +# num_attn_heads=16 +# global_batch_size=256 +# lr=3.0e-4 +# min_lr=1.0e-6 +# init_std=0.018 + + +## GPT-3 Large 760M +# model_size=0.76 +# num_layers=24 +# hidden_size=1536 +# num_attn_heads=16 +# global_batch_size=256 +# lr=2.5e-4 +# min_lr=1.0e-6 +# init_std=0.015 + + +## GPT-3 XL 1.3B +model_size=1.3 +num_layers=24 +hidden_size=2048 +num_attn_heads=16 +global_batch_size=16 +lr=2.0e-4 +min_lr=1.0e-6 +init_std=0.013 + + +## GPT-3 2.7B +# model_size=2.7 +# num_layers=32 +# hidden_size=2560 +# num_attn_heads=32 +# global_batch_size=512 +# lr=1.6e-4 +# min_lr=1.0e-6 +# init_std=0.011 + + +## GPT-3 6.7B +# model_size=6.7 +# num_layers=32 +# hidden_size=4096 +# num_attn_heads=32 +# global_batch_size=1024 +# lr=1.2e-4 +# min_lr=1.0e-6 +# init_std=0.009 + + +## GPT-3 13B +# model_size=13 +# num_layers=40 +# hidden_size=5120 +# num_attn_heads=40 +# global_batch_size=1024 +# lr=1.0e-4 +# min_lr=1.0e-6 +# init_std=0.008 + + +## GPT-3 175B +# model_size=175 +# num_layers=96 +# hidden_size=12288 +# num_attn_heads=96 +# global_batch_size=1536 +# lr=0.6e-4 +# min_lr=1.0e-6 +# init_std=0.005 +############################################################################### +### Training duration configs +## The main termination condition, original GPT-3 paper trains for 300B tokens. +train_tokens_in_billion=300 +train_tokens=$((${train_tokens_in_billion} * 1000000000)) + + +## train_samples is another termination condition and also affect the number of +## data samples to be indexed. Since we want to reach the train_tokens +## above, and data efficiency techniques may change num tokens in some samples, +## so we just set this config large enough to make sure we have enough +## processed data and don't terminate by train_samples. +train_samples=$(( 300 * 1000000000 * 2 / ${seq_len} )) + + +## Another wall-clock time termination condition in minutes. Set it large +## enough to avoid undesired early termination. +exit_duration=30000000 +############################################################################### +### lr configs +## lr warmup and decay duration. +## Original GPT-3 paper uses 375M warmup tokens and 260B cosine decay tokens. +## Here we increase the warmup tokens to 3B since when batch size warmup is not +## used, there are more tokens per step. Thus we need to increase warmup tokens +## to make sure there are enough warmup steps, which is important for training +## stability. +lr_warmup_tokens_in_million=3000 +lr_warmup_tokens=$((${lr_warmup_tokens_in_million} * 1000000)) +## Here we changed the LR decay tokens to align with total train tokens, since +## related works (e.g., https://arxiv.org/abs/2203.15556) find that setting the +## learning rate schedule to match the number of training tokens results in the +## best final model quality +lr_decay_tokens_in_billion=${train_tokens_in_billion} +lr_decay_tokens=$((${lr_decay_tokens_in_billion} * 1000000000)) +lr_decay_style="cosine" +############################################################################### +### Parallelism configs +## Model parallelism, 1 is no MP +mp_size=1 + + +## Pipeline parallelism. To disable PP, set pp_size to 1 and no_pp to true. +## Note that currently both curriculum learning and random-LTD are NOT +## compatible with pipeline parallelism. 
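Before the parallelism settings below, here is a small sketch of how the data-parallel size, the per-GPU micro-batch bound, and the oversized `train_samples` value are derived. The GPU count is an assumed single-node value for illustration (the script itself detects it with `ds_ssh`/`nvidia-smi`); the other numbers mirror the shell arithmetic that follows.

```python
# Assumed: one node with 8 GPUs; pp_size, mp_size, global_batch_size, seq_len as in this script.
num_gpus, pp_size, mp_size = 8, 8, 1
global_batch_size, seq_len = 16, 2048

dp_size = num_gpus // (pp_size * mp_size)                            # -> 1
max_micro_batch = global_batch_size * pp_size * mp_size // num_gpus  # -> 16 (script uses 1)
train_samples = 300 * 10**9 * 2 // seq_len                           # ~293M, deliberately oversized
print(dp_size, max_micro_batch, train_samples)
```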
+pp_size=8 +no_pp="false" + + +## ZeRO-based data parallelism, stage=0 will disable ZeRO +zero_stage=0 + + +## Total number of GPUs. ds_ssh is from DeepSpeed library. +num_gpus=$(($(ds_ssh nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)-2)) +num_gpus_pernode=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +num_node=$(( ${num_gpus} / ${num_gpus_pernode} )) + + +## Data parallel size. +dp_size=$(( ${num_gpus} / ${pp_size} / ${mp_size} )) + + +## Micro batch size per GPU +## Make sure that batch_size <= global_batch_size*pp_size*mp_size/num_gpus +## Reduce it manually if GPU OOM +# batch_size=$(( ${global_batch_size} / ${dp_size} )) +batch_size=1 +############################################################################### +### Misc configs +log_interval=1 +eval_iters=10100 +eval_interval=10100 +# num_save controls how frequent to save checkpoint. num_save=20 means that a +# checkpoint will be saved every 5% of training. For longer training you would +# want larger num_save to save more frequently, and vice versa. +num_save=1 +# estimated_train_iter=$((${train_tokens} / ${seq_len} / ${global_batch_size})) +# save_interval=$((${estimated_train_iter} / ${num_save})) +save_interval=10100 + + +## Activation checkpointing saves GPU memory, but reduces training speed +activation_checkpoint="false" +# activation_checkpoint="false" + + +## Whether or not log optimizer states (norms, max abs values) to tensorboard. +## This is not required for training and might save GPU memory when turned off. +log_optimizer_state="true" +############################################################################### +### Output and data configs +current_time=$(date "+%Y.%m.%d_%H.%M.%S") +host="${HOSTNAME}" +seed=1234 +num_workers=0 + + +## Public the Pile dataset, can be downloaded at +## https://mystic.the-eye.eu/public/AI/pile_neox/ or +## https://the-eye.eu/public/AI/pile_neox/ Change data_home to where you +## store the pile_text_document.bin and pile_text_document.idx. +data_home="/code" +data_path="${data_home}/gpt_data/my-gpt2_text_document" + + +vocab_path="gpt2-vocab.json" +if [ ! -f "$vocab_path" ]; then + wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json +fi +merge_path="gpt2-merges.txt" +if [ ! -f "$merge_path" ]; then + wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt +fi + + +prescale_grad="true" +jobname="gpt_${model_size}B_tok${train_tokens_in_billion}B" +jobname="${jobname}_lr${lr}_min${min_lr}_w${lr_warmup_tokens_in_million}M_d${lr_decay_tokens_in_billion}B_${lr_decay_style}" +jobname="${jobname}_gbs${global_batch_size}_mbs${batch_size}_g${num_gpus}" +if [[ $zero_stage -gt 0 ]]; then + jobname="${jobname}_z${zero_stage}" + prescale_grad="false" +fi +if [[ $mp_size -gt 1 ]]; then + jobname="${jobname}_mp${mp_size}" +fi +if [ "${no_pp}" = "false" ]; then + jobname="${jobname}_pp${pp_size}" +fi +jobname="${jobname}_seed${seed}_rebase" + + +username=$(whoami) +output_home="/blob/users/${username}/project/data_efficient_gpt" +log_path="${output_home}/log/" +checkpoint_path="${output_home}/checkpoint/${jobname}" +## Microsoft internal constraint: because tensorboard is logged by last rank, +## it's better to put the path in NFS instead of Blob. 
+tensorboard_dir="/vc_data/users/${username}/project/data_efficient_gpt/tensorboard/" +tensorboard_path="${tensorboard_dir}${jobname}_${host}_${current_time}" +mkdir -p ${log_path} +mkdir -p ${checkpoint_path} +mkdir -p ${tensorboard_path} +############################################################################### +data_options=" \ + --vocab-file ${vocab_path} \ + --merge-file ${merge_path} \ + --data-path ${data_path} \ + --data-impl mmap" + + +## If CL is used, make sure to set "--split" the same as what you used during +## offline data analysis&indexing. +megatron_options=" \ + --override-opt_param-scheduler \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --tensor-model-parallel-size ${mp_size} \ + --init-method-std ${init_std} \ + --lr-decay-tokens ${lr_decay_tokens} \ + --lr-warmup-tokens ${lr_warmup_tokens} \ + --micro-batch-size ${batch_size} \ + --exit-duration-in-mins ${exit_duration} \ + --global-batch-size ${global_batch_size} \ + --num-layers ${num_layers} \ + --hidden-size ${hidden_size} \ + --num-attention-heads ${num_attn_heads} \ + --seq-length ${seq_len} \ + --max-position-embeddings ${seq_len} \ + --train-tokens ${train_tokens} \ + --train-samples ${train_samples} \ + --lr ${lr} \ + --min-lr ${min_lr} \ + --lr-decay-style ${lr_decay_style} \ + --split 949,50,1 \ + --log-interval ${log_interval} \ + --eval-interval ${eval_interval} \ + --eval-iters ${eval_iters} \ + --save-interval ${save_interval} \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --hysteresis 2 \ + --num-workers ${num_workers} \ + --fp16 \ + --seed ${seed} \ + --load ${checkpoint_path} \ + --save ${checkpoint_path} \ + --no-async-tensor-model-parallel-allreduce \ + --tensorboard-queue-size 1 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-dir ${tensorboard_path}" + + +if [ "${activation_checkpoint}" = "true" ]; then +megatron_options="${megatron_options} \ + --checkpoint-activations" +fi + + +if [ "${log_optimizer_state}" = "true" ]; then +megatron_options="${megatron_options} \ + --log-optimizer-states-to-tensorboard" +fi + + +config_json="ds_config_gbs${global_batch_size}_mbs${batch_size}_log${log_interval}_zero${zero_stage}.json" +template_json="../rebase/ds_config_gpt_TEMPLATE.json" +sed "s/GBSIZE/${global_batch_size}/" ${template_json} \ + | sed "s/MBSIZE/${batch_size}/" \ + | sed "s/LOG_INTERVAL/${log_interval}/" \ + | sed "s/ZERO_STAGE/${zero_stage}/" \ + | sed "s/PRESCALE_GRAD/${prescale_grad}/" \ + > ${config_json} + + +deepspeed_options=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${zero_stage} \ + --enable-zbh1-pipeline \ + --enable-zbh1-exact-semantics \ + --pipeline-model-parallel-size ${pp_size}" + + +if [[ "${no_pp}" = "true" ]]; then +deepspeed_options="${deepspeed_options} \ + --no-pipeline-parallel" +fi + + +if [ "${activation_checkpoint}" = "true" ]; then +deepspeed_options="${deepspeed_options} \ + --deepspeed-activation-checkpointing" +fi + + +## When saving checkpoint to a storage with cache, their could be consistency +## issue of the pointer to latest checkpoint. Here we find the correct pointer +## and broadcast it to all nodes. 
+iteration_file="$checkpoint_path/latest_checkpointed_iteration.txt" +iteration_file_2="$checkpoint_path/latest" +iteration=0 +for (( node = 0; node <= num_node-1; node++ )) +do + if $(ssh -q worker-"$node" "test -f \"$iteration_file\""); then + local_iteration=$(ssh -q worker-"$node" cat $iteration_file) + iteration=$(( ${local_iteration} > ${iteration} ? ${local_iteration} : ${iteration} )) + fi +done +if [[ $iteration -gt 0 ]]; then + iteration_2="global_step${iteration}" + ds_ssh "echo $iteration > $iteration_file" + ds_ssh "echo $iteration_2 > $iteration_file_2" +fi + + +deepspeed ${dir}/../../pretrain_gpt.py ${megatron_options} ${data_options} ${deepspeed_options} 2>&1 | tee log_zbh1_exact.txt \ No newline at end of file diff --git a/mds_to_hf.py b/mds_to_hf.py index 9f2d87cc48..d91513ed8b 100644 --- a/mds_to_hf.py +++ b/mds_to_hf.py @@ -1,130 +1,106 @@ -# Usage : python mds_to_hf.py --mds_checkpoint --output_dir -# Tips : Do not run on login node. -# This script currently only takes care of tp=1. Takes a AuroraGPT Llama model trained with Megatron-DeepSpeed and converts to LLamaCausalForLM architecture from HuggingFace. +# Usage : python mds_to_hf.py --mds_checkpoint --output_dir --cache-dir /flare/Aurora_deployment/vsastry +# Tips : Do not run on login node. +# This script currently only takes care of tp=1. Takes a AuroraGPT Llama model trained with Megatron-DeepSpeed and converts to LLamaCausalForLM architecture from HuggingFace. import argparse import torch -import pdb import os from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - -def repeat_kv_wt(x, np): +def repeat_kv_wt(x,np): return torch.repeat_interleave(x, dim=0, repeats=np) - def Update_llama_config(Llama_config, mds_args): - if mds_args["swiglu"]: + if mds_args['swiglu']: Llama_config.hidden_act = "silu" - Llama_config.hidden_size = mds_args["hidden_size"] - Llama_config.intermediate_size = mds_args["ffn_hidden_size"] - Llama_config.max_position_embeddings = mds_args["max_position_embeddings"] - Llama_config.num_attention_heads = mds_args["num_attention_heads"] - Llama_config.num_hidden_layers = mds_args["num_layers"] - Llama_config.num_key_value_heads = mds_args["num_key_value_heads"] - Llama_config.rms_norm_eps = mds_args["layernorm_epsilon"] - Llama_config.rope_theta = mds_args["rope_theta"] - Llama_config.vocab_size = mds_args["padded_vocab_size"] + Llama_config.hidden_size = mds_args['hidden_size'] + Llama_config.intermediate_size = mds_args['ffn_hidden_size'] + Llama_config.max_position_embeddings = mds_args['max_position_embeddings'] + Llama_config.num_attention_heads = mds_args['num_attention_heads'] + Llama_config.num_hidden_layers = mds_args['num_layers'] + Llama_config.num_key_value_heads = mds_args['num_key_value_heads'] + Llama_config.rms_norm_eps = mds_args['layernorm_epsilon'] + Llama_config.rope_theta = mds_args['rope_theta'] + Llama_config.vocab_size = mds_args['padded_vocab_size'] + if mds_args['fp16'] == True: + Llama_config.torch_dtype = 'float16' + elif mds_args['bf16'] == True: + Llama_config.torch_dtype = 'bfloat16' return Llama_config if __name__ == "__main__": - from pathlib import Path - parser = argparse.ArgumentParser() - parser.add_argument("--mds_checkpoint", required=True) - parser.add_argument("--output_dir", required=True) + parser.add_argument('--mds_checkpoint', required=True) + parser.add_argument('--output_dir', required=True) + parser.add_argument('--cache_dir', required=True) args = parser.parse_args() # make output_dir if it does not exits. 
if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - filename = str(args.mds_checkpoint) - if not filename.split("/")[-1].startswith("mp_rank") and not filename.split("/")[ - -1 - ].endswith(".pt"): - assert "Provide the right file path, The file should be of format mp_rank_*.pt" + filename = str(args.mds_checkpoint) + if not filename.split("/")[-1].startswith('mp_rank') and not filename.split("/")[-1].endswith('.pt'): + assert ("Provide the right file path, The file should be of format mp_rank_*.pt") print(f"loading mds checkpoint {filename}") - - cache_dir = Path(os.getcwd()).joinpath(".cache", "hugging_face") - cache_dir.mkdir(exist_ok=True, parents=True) - - mds_model = torch.load(args.mds_checkpoint, map_location=torch.device("cpu")) - Llama_model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", cache_dir=cache_dir.as_posix() - ) - Llama_config = Llama_model.config - Updated_Llama_config = Update_llama_config(Llama_config, mds_model["args"].__dict__) - # save the updated config.json file - Updated_Llama_config.to_json_file(os.path.join(args.output_dir, "config.json")) + + mds_model = torch.load(args.mds_checkpoint,map_location=torch.device('cpu')) + Llama_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",cache_dir=args.cache_dir) + + Llama_config = Llama_model.config + Updated_Llama_config = Update_llama_config(Llama_config, mds_model['args'].__dict__) + # save the updated config.json file + Updated_Llama_config.to_json_file(os.path.join(args.output_dir,'config.json')) state_dict = {} - dim = mds_model["args"].__dict__["kv_channels"] - inv_freq = 1.0 / ( - mds_model["args"].__dict__["rope_theta"] - ** (torch.arange(0, dim, 2).float() / dim) - ) - hidden_size = mds_model["args"].__dict__["hidden_size"] - kv_dim = ( - mds_model["args"].__dict__["kv_channels"] - * mds_model["args"].__dict__["num_key_value_heads"] - ) - kv_groups = ( - mds_model["args"].__dict__["num_attention_heads"] - // mds_model["args"].__dict__["num_key_value_heads"] - ) - for layer_i in range(Updated_Llama_config.__dict__["num_hidden_layers"]): + dim = mds_model['args'].__dict__['kv_channels'] + inv_freq = 1.0 / (mds_model['args'].__dict__['rope_theta'] ** (torch.arange(0,dim, 2).float() / dim)) + hidden_size = mds_model['args'].__dict__['hidden_size'] + kv_dim = mds_model['args'].__dict__['kv_channels'] * mds_model['args'].__dict__['num_key_value_heads'] + kv_groups = mds_model['args'].__dict__['num_attention_heads'] // mds_model['args'].__dict__['num_key_value_heads'] + nkvheads = mds_model['args'].__dict__['num_key_value_heads'] + for layer_i in range(Updated_Llama_config.__dict__['num_hidden_layers']): # SELF ATTENTION layers. - # get the q, k, v weights separately. Keeping k and v at the GQA head dim, since the transformers/models/llama/modelling_utils will take care of it. 
- fused_qkv = mds_model["module"]["language_model"]["encoder"][ - f"layers.{layer_i}.self_attention.query_key_value.weight" - ] - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = fused_qkv[ - 0:hidden_size - ] - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = fused_qkv[ - hidden_size : hidden_size + kv_dim - ] - # state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = repeat_kv_wt(fused_qkv[hidden_size:hidden_size+kv_dim], kv_groups) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = fused_qkv[ - hidden_size + kv_dim : hidden_size + 2 * kv_dim - ] - # state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = repeat_kv_wt(fused_qkv[hidden_size+kv_dim:hidden_size+2*kv_dim],kv_groups) - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = mds_model[ - "module" - ]["language_model"]["encoder"][f"layers.{layer_i}.self_attention.dense.weight"] - # MLP Layers - fused_mlp = mds_model["module"]["language_model"]["encoder"][ - f"layers.{layer_i}.mlp.dense_h_to_4h.weight" - ] - chunked_mlp = torch.chunk(fused_mlp, 2, dim=0) + # get the q, k, v weights separately. Keeping k and v at the GQA head dim, since the transformers/models/llama/modelling_utils will take care of it. + fused_qkv = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.self_attention.query_key_value.weight"] + fused_reshape = fused_qkv.view(nkvheads,(kv_groups+2)*dim,hidden_size) + ex_q = fused_reshape[:,:kv_groups*dim,:] + con_q = ex_q.contiguous().view(-1, fused_reshape.size(2)) + + ex_k = fused_reshape[:,kv_groups*dim:(kv_groups+1)*dim,:] + con_k = ex_k.contiguous().view(-1, fused_reshape.size(2)) + + ex_v = fused_reshape[:,(kv_groups+1)*dim:(kv_groups+2)*dim,:] + con_v = ex_v.contiguous().view(-1, fused_reshape.size(2)) + + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = con_q + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = con_k + #state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = repeat_kv_wt(fused_qkv[hidden_size:hidden_size+kv_dim], kv_groups) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = con_v + #state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = repeat_kv_wt(fused_qkv[hidden_size+kv_dim:hidden_size+2*kv_dim],kv_groups) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.self_attention.dense.weight"] + + # MLP Layers + fused_mlp = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.mlp.dense_h_to_4h.weight"] + chunked_mlp = torch.chunk(fused_mlp,2,dim=0) state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = chunked_mlp[0] state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = chunked_mlp[1] - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = mds_model[ - "module" - ]["language_model"]["encoder"][f"layers.{layer_i}.mlp.dense_4h_to_h.weight"] - # LayerNorm weights and RoPe - state_dict[f"model.layers.{layer_i}.input_layernorm.weight"] = mds_model[ - "module" - ]["language_model"]["encoder"][f"layers.{layer_i}.input_layernorm.weight"] - state_dict[ - f"model.layers.{layer_i}.post_attention_layernorm.weight" - ] = mds_model["module"]["language_model"]["encoder"][ - f"layers.{layer_i}.post_attention_layernorm.weight" - ] + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.mlp.dense_4h_to_h.weight"] + + #LayerNorm weights and RoPe + 
state_dict[f"model.layers.{layer_i}.input_layernorm.weight"] = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.input_layernorm.weight"] + state_dict[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = mds_model['module']['language_model']['encoder'][f"layers.{layer_i}.post_attention_layernorm.weight"] state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq - # Get the non-encoder layer weights. - state_dict["model.embed_tokens.weight"] = mds_model["module"]["language_model"][ - "embedding" - ]["word_embeddings"]["weight"] - state_dict["model.norm.weight"] = mds_model["module"]["language_model"]["encoder"][ - "final_layernorm.weight" - ] - state_dict["lm_head.weight"] = mds_model["module"]["language_model"][ - "output_layer" - ]["weight"] - # Save the model in the hf output path. - torch.save(state_dict, os.path.join(args.output_dir, "pytorch_model.bin")) + # Get the non-encoder layer weights. + state_dict["model.embed_tokens.weight"] = mds_model['module']['language_model']['embedding']['word_embeddings']['weight'] + state_dict["model.norm.weight"] = mds_model['module']['language_model']['encoder']['final_layernorm.weight'] + state_dict["lm_head.weight"] = mds_model['module']['language_model']['output_layer']['weight'] + + # Save the model in the hf output path. + torch.save(state_dict, os.path.join(args.output_dir,"pytorch_model.bin")) + + + diff --git a/megatron/__init__.py b/megatron/__init__.py index d92a279ec6..4a7fe2cd82 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -9,6 +9,7 @@ from .global_vars import update_num_microbatches from .global_vars import get_tokenizer from .global_vars import get_tensorboard_writer +from .global_vars import get_wandb_writer from .global_vars import get_adlr_autoresume from .global_vars import get_timers from .initialize import initialize_megatron diff --git a/megatron/arguments.py b/megatron/arguments.py index 307e725e51..9ab3e40953 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Megatron arguments.""" @@ -44,6 +45,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): parser = _add_inference_args(parser) parser = _add_transformer_engine_args(parser) parser = _add_retro_args(parser) + parser = _add_profiler_args(parser) # Custom arguments. if extra_args_provider is not None: @@ -73,6 +75,12 @@ def validate_args(args, defaults={}): assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ ' ({}) is not divisible by tensor model parallel size ({})'.format( args.world_size, args.tensor_model_parallel_size) + # Zero bubble pipeline is defined on deepspeed's scheduler + if args.enable_zbh1_pipeline: + assert args.deepspeed, 'Use DeepSpeed to use zero-bubble H1 pipeline' + assert args.sequence_parallel == False, "Sequence Parallel not tested, proceed at own will by removing this line" + if args.enable_zbh1_exact_semantics: + assert args.enable_zbh1_pipeline, 'Exact semantics require ZBH1 pipeline enabled' # Pipeline model parallel size. 
args.pipeline_model_parallel_size = min( args.pipeline_model_parallel_size, @@ -95,8 +103,8 @@ def validate_args(args, defaults={}): args.ds_sequence_parallel_size assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\ ' divisible by tensor parallel size ({}) times pipeline parallel ' \ - 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size) + 'size ({}) times seqence parallel size ({})'.format(args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, args.ds_sequence_parallel_size) args.data_parallel_size = args.world_size // model_parallel_size if args.rank == 0: print('using world size: {}, data-parallel-size: {}, ' @@ -391,7 +399,8 @@ def validate_args(args, defaults={}): args.async_tensor_model_parallel_allreduce = False if not args.use_dataset_only: - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if deepspeed.accelerator.get_accelerator().device_name() == "cuda" \ + and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if args.sequence_parallel: raise RuntimeError( "Using sequence parallelism requires setting the environment variable " @@ -672,6 +681,9 @@ def _add_network_size_args(parser): help='Untie embeddings and output weights.'), group.add_argument('--embedding-weights-in-fp32', action='store_true', help='Cast word embedding weights to fp32 before embedding fwd.'), + group.add_argument('--kill-switch-file', type=str, default=None, + help='Location of kill switch file. ' + 'If found will automatically exit the program at runtime.') return parser @@ -740,6 +752,12 @@ def _add_logging_args(parser): group.add_argument('--log-world-size-to-tensorboard', action='store_true', help='Enable world size logging to tensorboard.') + group.add_argument('--wandb-project', type=str, default='', + help='The wandb project name. 
Ignore wandb by default.') + group.add_argument('--wandb-exp-name', type=str, default='', + help='The wandb experiment name.') + group.add_argument('--wandb-save-dir', type=str, default='', + help='Path to save the wandb results locally.') return parser @@ -762,6 +780,15 @@ def _add_regularization_args(parser): help='Weight decay increment function.') group.add_argument('--clip-grad', type=float, default=1.0, help='Gradient clipping based on global L2 norm.') + group.add_argument('--sophiag-beta1', type=float, default=0.9, + help='First coefficient for computing running averages ' + 'of gradient and its hessian') + group.add_argument('--sophiag-beta2', type=float, default=0.95, + help='Second coefficient for computing running averages ' + 'of gradient and its hessian') + group.add_argument('--sophiag-rho', type=float, default=0.01, + help='SophiaG clipping threshhold') + group.add_argument('--adam-beta1', type=float, default=0.9, help='First coefficient for computing running averages ' 'of gradient and its square') @@ -835,6 +862,10 @@ def _add_training_args(parser): 'uniformly divided recompute unit, ' '2) block: the number of individual Transformer layers ' 'to recompute within each pipeline stage.') + group.add_argument('--enable-zbh1-pipeline', action='store_true', + help='Activate zero bubble pipeline parallelism schedule method') + group.add_argument('--enable-zbh1-exact-semantics', action='store_true', + help='Use an exact semantics for zbh1 schedule, might be slower than the default.') # deprecated # HACK: added back arguments because DeepSpeed still relies on the old @@ -924,6 +955,7 @@ def _add_training_args(parser): choices=[ 'adam', 'adamw', + 'sophiag', 'sgd', 'ds.fusedlamb', 'ipex.lamb', @@ -990,7 +1022,14 @@ def _add_training_args(parser): dest='gradient_accumulation_fusion') group.add_argument('--use-dataset-only', type=bool, required=False, default=False, help='If set to True, only use the megatron dataset for external trainer ') - group.add_argument('--profile', action='store_true', help='Enable Torch Profiler') + # group.add_argument('--profile', action='store_true', help='Enable Torch Profiler') + group.add_argument( + "--train-range-to-skip", + action="extend", + nargs="+", + type=int, + help=("Range of iters to skip during training. Must be in pairs."), + ) group.add_argument('--train-iters-to-skip', action="extend", nargs="+", type=str, help=( "Specific train iterations to skip when training. " @@ -1320,6 +1359,8 @@ def _add_data_args(parser): help='What type of tokenizer to use.') group.add_argument('--tokenizer-model', type=str, default=None, help='Sentencepiece tokenizer model.') + group.add_argument('--trust-remote-code', action='store_true', default=False, + help='To run HFTokenizer model from local path.') group.add_argument('--data-impl', type=str, default='infer', choices=['mmap', 'infer'], help='Implementation of indexed datasets.') @@ -1573,3 +1614,26 @@ def _add_distillation_args(parser): help='Directory containing a teacher model checkpoint.') return parser + + +def _add_profiler_args(parser): + group = parser.add_argument_group(title='profiling configuration') + + group.add_argument("--profile", + type=str, + default=None, + choices=['pt', 'pt-full'], + help="Enable profiling, pt-full gives call stack compared to pt") + + group.add_argument("--profile_steps", + type=str, + default='2,3', + help="Which steps to profile. Format: ,") + + group.add_argument("--profile-ranks", + type=int, + nargs='+', + default=None, + help="Which ranks to profile. 
Format: 0 1 2 3") + + return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index b7f4b30bde..a4f82ec9d3 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -239,7 +239,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): """Save a model checkpoint.""" args = get_args() assert args is not None - iteration = args.iteration + args_iter = args.iteration + if args_iter != iteration: + log.warning(f"{args.iteration=} != {iteration} passed to 'save_checkpoint'") save_lr_state_dict() diff --git a/megatron/core/pipeline_parallel/deepspeed_zbh1_engine.py b/megatron/core/pipeline_parallel/deepspeed_zbh1_engine.py new file mode 100644 index 0000000000..ba451d9705 --- /dev/null +++ b/megatron/core/pipeline_parallel/deepspeed_zbh1_engine.py @@ -0,0 +1,110 @@ +from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore + +from deepspeed.runtime.pipe.engine import PipelineEngine +from deepspeed.utils.timer import BACKWARD_MICRO_TIMER, \ + BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER +from deepspeed.runtime.utils import PartitionedTensor +from deepspeed.accelerator import get_accelerator + +import torch +from torch.cuda.amp import custom_bwd +from packaging import version + + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, + get_global_memory_buffer, +) + +def _exec_backward_only_pass(self, buffer_id): + assert self.optimizer is not None, "must provide optimizer during " \ + "init in order to use backward" + + self.mem_status('BEFORE BWD ONLY', reset_max=True) + from megatron.core.tensor_parallel.layers import LinearWithGradAccumulationAndAsyncCommunication + WeightGradStore.set_combine_bw(False) + # The last stage just runs backward on the loss using DeepSpeed's typical + # mechanisms. + if self.is_last_stage(): + super(PipelineEngine, self).backward(self.loss) + WeightGradStore.flush() + self.mem_status('AFTER BWD ONLY') + + WeightGradStore.set_combine_bw(True) + return + + outputs = self.pipe_buffers['outputs'][buffer_id] + + if self.wall_clock_breakdown(): + self.timers(BACKWARD_MICRO_TIMER).start() + self.timers(BACKWARD_GLOBAL_TIMER).start() + self.timers(BACKWARD_INNER_MICRO_TIMER).start() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).start() + + # Reconstruct if we previously partitioned the output. We must be + # careful to also restore the computational graph of the tensors we partitioned. 
+ if self.is_pipe_partitioned: + if self.is_grad_partitioned: + if self.pipe_partition_output_meta_cache is None: + self.pipe_partition_output_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache, + local_part=outputs[1], + group=self.grid.get_slice_parallel_group()) + self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() + outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:]) + else: + # Already restored from partition + self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0] + outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:]) + + grad_tensors = self.grad_layer + if self.is_grad_partitioned: + if self.grad_partition_grad_layer_meta_cache is None: + self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu') + part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache, + local_part=self.grad_layer[1], + group=self.grid.get_slice_parallel_group()) + grad_tensors = (part_grad.full(), *grad_tensors[2:]) + part_grad = None + + if self.using_bf16_optimizer and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + self.optimizer.clear_lp_grads() + + # This handles either a single tensor or tuple of tensors. + + if isinstance(outputs, tuple): + out_tensors = [t for t in outputs if t.is_floating_point()] + assert len(out_tensors) == len(grad_tensors) + torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) + else: + torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + + + WeightGradStore.flush() + + if self.using_bf16_optimizer and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + self.optimizer.update_hp_grads(clear_lp_grads=False) + + # Free up the memory from the output of forward() + self.pipe_buffers['output_tensors'][buffer_id] = None + self.pipe_buffers['outputs'][buffer_id] = None + grad_tensors = None + + WeightGradStore.set_combine_bw(True) + + if self.wall_clock_breakdown(): + self.timers(BACKWARD_INNER_MICRO_TIMER).stop() + self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop() + self.timers(BACKWARD_MICRO_TIMER).stop() + self.timers(BACKWARD_GLOBAL_TIMER).stop() + +def _exec_weight_pass(self): + if self.using_bf16_optimizer: + # manually call because we don't call optimizer.backward() + self.optimizer.clear_lp_grads() + WeightGradStore.pop() + if self.using_bf16_optimizer: + self.optimizer.update_hp_grads(clear_lp_grads=False) \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/deepspeed_zbh1_schedule.py b/megatron/core/pipeline_parallel/deepspeed_zbh1_schedule.py new file mode 100644 index 0000000000..651aadce72 --- /dev/null +++ b/megatron/core/pipeline_parallel/deepspeed_zbh1_schedule.py @@ -0,0 +1,148 @@ +from deepspeed.runtime.pipe.schedule import PipeSchedule, PipeInstruction, BufferOpInstruction, \ + LoadMicroBatch, RecvActivation, SendActivation, RecvGrad, SendGrad, \ + ForwardPass, BackwardPass, ReduceGrads, ReduceTiedGrads, OptimizerStep +from megatron import get_args + +class ZeroBubbleH1Pipeline(PipeSchedule): + """A schedule for training a batch using hybrid parallelism. + + Pipeline parallelism is extracted through gradient accumulation and thus + convergence follows that of a data parallel approach with the same batch + size. 
+ """ + + def steps(self): + num_warmup_microbatches = self.stages - self.stage_id + + forward = 0 + backward = 0 + weight = 0 + + # F section + for _ in range(num_warmup_microbatches - 1): + if forward == self.micro_batches: + continue + forward_id = self.get_buffer_id(forward) + forward += 1 + + cmds = [] + if not self.is_first_stage: + cmds.append(RecvActivation(forward_id)) + if self.is_first_stage or self.is_last_stage: + cmds.append(LoadMicroBatch(forward_id)) + cmds.append(ForwardPass(forward_id)) + if not self.is_last_stage: + cmds.append(SendActivation(forward_id)) + yield cmds + + # FB section + for _ in range(self.stage_id): + if forward == self.micro_batches: + continue + forward_id = self.get_buffer_id(forward) + backward_id = self.get_buffer_id(backward) + forward += 1 + backward += 1 + + cmds = [] + if not self.is_first_stage: + cmds.append(RecvActivation(forward_id)) + if self.is_first_stage or self.is_last_stage: + cmds.append(LoadMicroBatch(forward_id)) + cmds.append(ForwardPass(forward_id)) + if not self.is_last_stage: + cmds.append(RecvGrad(backward_id)) + cmds.append(SendActivation(forward_id)) + cmds.append(BackwardOnlyPass(backward_id)) + if not self.is_first_stage: + cmds.append(SendGrad(backward_id)) + yield cmds + + # FBW section + while forward < self.micro_batches: + forward_id = self.get_buffer_id(forward) + backward_id = self.get_buffer_id(backward) + forward += 1 + backward += 1 + weight += 1 + + cmds = [] + if not self.is_first_stage: + cmds.append(RecvActivation(forward_id)) + if self.is_first_stage or self.is_last_stage: + cmds.append(LoadMicroBatch(forward_id)) + cmds.append(ForwardPass(forward_id)) + if not self.is_last_stage: + cmds.append(RecvGrad(backward_id)) + cmds.append(SendActivation(forward_id)) + if self.is_first_stage: + cmds.append(BackwardPass(backward_id)) + elif forward == self.micro_batches: + cmds.append(BackwardOnlyPass(backward_id)) + cmds.append(SendGrad(backward_id)) + cmds.append(WeightPass()) + else: + if get_args().enable_zbh1_exact_semantics: + cmds.append(BackwardOnlyPass(backward_id)) + cmds.append(SendGrad(backward_id)) + cmds.append(WeightPass()) + else: + cmds.append(BackwardPass(backward_id)) + cmds.append(SendGrad(backward_id)) + yield cmds + + #BW section + while backward < self.micro_batches: + backward_id = self.get_buffer_id(backward) + backward += 1 + weight += 1 + + cmds = [] + if not self.is_last_stage: + cmds.append(RecvGrad(backward_id)) + if self.is_first_stage: + cmds.append(BackwardPass(backward_id)) + else: + cmds.append(BackwardOnlyPass(backward_id)) + cmds.append(SendGrad(backward_id)) + cmds.append(WeightPass()) + yield cmds + + #W section + while weight < self.micro_batches: + weight += 1 + yield [WeightPass()] + + yield [ReduceTiedGrads(), ReduceGrads(), OptimizerStep()] + + def get_buffer_id(self, microbatch_id): + num_warmup_microbatches = self.stages - self.stage_id + return microbatch_id % num_warmup_microbatches + + +##Additional Instruction classes +class BackwardOnlyPass(BufferOpInstruction): + """Compute a backward pass and accumulate gradients. + + Roughly: + + .. code-block:: python + + outputs = buffers['outputs'][buffer_id] + gradients = buffers['gradients'][buffer_id] + torch.autograd.backward(tensors=outputs, + grad_tensors=gradients, inputs = input_tensor) + """ + pass + +class WeightPass(PipeInstruction): + """Compute a weight pass and accumulate gradients. + + Roughly: + + .. 
code-block:: python + + torch.autograd.backward(tensors=outputs, + grad_tensors=gradients, inputs = model.parameters()) + """ + pass diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 020d25915a..3dd3299ae0 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Parts of the code here are adapted from PyTorch @@ -16,6 +17,8 @@ from torch.cuda.amp import custom_fwd, custom_bwd +from megatron import get_args + from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.parallel_state import ( @@ -233,6 +236,11 @@ def __init__(self, sequence_length, embedding_dim): def forward(self, position_ids): return self.local_embeddings(position_ids - self.offset) +def gradientUpdateFunction(total_input, grad_output, weight): + if weight.grad == None: + weight.grad = grad_output.t().matmul(total_input) + else: + weight.grad += grad_output.t().matmul(total_input) class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """See linear_with_grad_accumulation_and_async_allreduce""" @@ -278,6 +286,7 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, @staticmethod @custom_bwd def backward(ctx, grad_output): + args = get_args() input, weight = ctx.saved_tensors use_bias = ctx.use_bias @@ -359,7 +368,13 @@ def backward(ctx, grad_output): # grad_weight = None # else: # grad_weight = grad_output.t().matmul(total_input) - grad_weight = grad_output.t().matmul(total_input) + if args.enable_zbh1_pipeline: + from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore + WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.sequence_parallel: @@ -441,7 +456,8 @@ def linear_with_grad_accumulation_and_async_allreduce( ] if not linear_with_grad_accumulation_and_async_allreduce.warned: - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if get_accelerator().device_name() == "cuda" \ + and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if sequence_parallel: warnings.warn( "When using sequence parallelism it is recommended to set the " diff --git a/megatron/core/tensor_parallel/weight_grad_store.py b/megatron/core/tensor_parallel/weight_grad_store.py new file mode 100644 index 0000000000..bbd1aea533 --- /dev/null +++ b/megatron/core/tensor_parallel/weight_grad_store.py @@ -0,0 +1,34 @@ +import queue + +class WeightGradStore: + + cache = [] + weight_grad_queue = queue.Queue() + combine_bw = True + + @classmethod + def set_combine_bw(cls, combine_bw): + # For the following backward pass, combine W with B and skip next W. + cls.combine_bw = combine_bw + + @classmethod + def put(cls, total_input, grad_output, weight, func): + if cls.combine_bw == True: + func(total_input, grad_output, weight) + return + # Store the weight gradient computation of linear layers. + cls.cache.append((total_input, grad_output, weight, func)) + + @classmethod + def flush(cls): + # Collect all stored computations during backward as a W. + cls.weight_grad_queue.put(cls.cache) + cls.cache = [] + + @classmethod + def pop(cls): + # Execute a single W. 
+ assert cls.weight_grad_queue.qsize() > 0 + stored_grads = cls.weight_grad_queue.get() + for total_input, grad_output, weight, func in stored_grads: + func(total_input, grad_output, weight) \ No newline at end of file diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index ba2e00b1ef..590a379971 100755 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -6,14 +6,20 @@ import os import time +import logging import numpy as np import torch from deepspeed.accelerator import get_accelerator -from megatron import print_rank_0 +# from megatron import print_rank_0 from megatron.core import mpu from megatron.utils import Profile, PerfTrace from mpi4py import MPI + +from megatron.utils import get_logger + +log = get_logger(__name__, rank_zero_only=True) + dlp = Profile("DATASET") class BlendableDataset(torch.utils.data.Dataset): @dlp.log @@ -43,7 +49,7 @@ def _build_indices(): helpers.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, self.size, torch.distributed.get_rank() == 0) - print_rank_0('> elapsed time for building blendable dataset indices: ' + log.info('> elapsed time for building blendable dataset indices: ' '{:.2f} (sec)'.format(time.time() - start_time)) return dataset_index, dataset_sample_index @@ -68,7 +74,7 @@ def _build_indices(): ' dataset, building indices on rank 0 ...', flush=True) dataset_index, dataset_sample_index = _build_indices() try: - print_rank_0(" > saving index map files") + log.info(" > saving index map files") start_time = time.time() os.makedirs(os.path.dirname(index_path), exist_ok=True) with open(desc_path, 'wt') as fd: @@ -76,7 +82,7 @@ def _build_indices(): np.save(index_path, dataset_index, allow_pickle=True) np.save(sample_index_path, dataset_sample_index, allow_pickle=True) - print_rank_0(f" > finished saving index map files in {time.time() - start_time} seconds") + log.info(f" > finished saving index map files in {time.time() - start_time} seconds") except OSError: print(f'There was an error trying to create the data cache directory ({data_cache_path})') print('or a file in it. This is set with the --data-cache-path argument. 
Please') @@ -93,7 +99,7 @@ def _build_indices(): torch.distributed.get_world_size() // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) // torch.distributed.get_world_size(group=mpu.get_sequence_parallel_group())): - print_rank_0("Data index creation unsuccessful, exiting.") + log.info("Data index creation unsuccessful, exiting.") exit() ''' torch.distributed.barrier(group=mpu.get_data_parallel_group()) @@ -101,13 +107,13 @@ def _build_indices(): torch.distributed.barrier(group=mpu.get_data_parallel_group()) start_time = time.time() - print_rank_0(f'> loading blendable dataset index: {index_path}') + log.info(f'> loading blendable dataset index: {index_path}') self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r') assert self.dataset_index.size == self.size - print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}') + log.info(f'> loading blendable dataset sample index: {sample_index_path}') self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r') assert self.dataset_sample_index.size == self.size - print_rank_0(f'> finished loading in {time.time() - start_time} seconds') + log.info(f'> finished loading in {time.time() - start_time} seconds') else: self.dataset_index, self.dataset_sample_index = _build_indices() @@ -119,7 +125,7 @@ def _build_indices(): raise RuntimeError('BlendedDataset size is improperly bounded') except IndexError: pass - print_rank_0('> size of blendable dataset: ' + log.info('> size of blendable dataset: ' '{} samples'.format(self.size)) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 8c32be7d8e..c412d02b31 100755 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -9,67 +9,96 @@ import numpy as np import torch from deepspeed.accelerator import get_accelerator -from megatron import print_rank_0, is_rank_0, get_args +from megatron import is_rank_0, get_args from megatron.core import mpu -from megatron.data import helpers +from megatron.data import helpers # type:ignore from megatron.data.blendable_dataset import BlendableDataset -from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_datasets_corpuses_weights_and_num_samples +from megatron.data.dataset_utils import ( + get_datasets_weights_and_num_samples, + get_datasets_corpuses_weights_and_num_samples, +) from megatron.data.dataset_utils import get_train_valid_test_split_ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset -from megatron.utils import PerfTrace, Profile +from megatron.utils import PerfTrace, Profile, get_logger from mpi4py import MPI dlp = Profile("DATASET") +log = get_logger(__name__, rank_zero_only=True) + + @dlp.log -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - train_data_prefix=None, - valid_data_prefix=None, - test_data_prefix=None, - return_doc_ids=False, *, - data_cache_path=None): +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" if data_prefix: - print_rank_0("Single data path provided for train, valid & test") + log.debug("Single data path provided for train, valid & test") # Single dataset. 
if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) # Blending dataset. # Parse the values. - output = get_datasets_corpuses_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_corpuses_weights_and_num_samples( + data_prefix, train_valid_test_num_samples + ) prefixes, corpuses, weights, datasets_train_valid_test_num_samples = output corpus_list = sorted(set(corpuses)) train_num_samples, valid_num_samples, test_num_samples = map( - sum, - zip(*datasets_train_valid_test_num_samples) + sum, zip(*datasets_train_valid_test_num_samples) ) class DatasetBuilder: - ''' + """ This is for building individual dataset from each dataset file - ''' + """ + @dlp.log - def __init__(self, prefix, corpus, data_impl, splits_string, - num_samples, seq_length, seed, skip_warmup, - return_doc_ids, - data_cache_path=data_cache_path, name='train'): + def __init__( + self, + prefix, + corpus, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + return_doc_ids, + data_cache_path=data_cache_path, + name="train", + ): self.prefix = prefix self.data_impl = data_impl self.splits_string = splits_string - if name == 'train': + if name == "train": self.num_samples = num_samples[0] - elif name == 'valid': + elif name == "valid": self.num_samples = num_samples[1] else: self.num_samples = num_samples[2] @@ -84,11 +113,21 @@ def __init__(self, prefix, corpus, data_impl, splits_string, self.desc = prefix + f"{self.num_samples}" + f"{seq_length}" + f"{seed}" self.build = False self.corpus = corpus + @dlp.log def Build(self): - self.dataset = _build_train_valid_test_datasets_single(self.prefix, self.data_impl, self.splits_string, - self.num_samples_train_valid_test, self.seq_length, self.seed, self.skip_warmup, self.name, self.return_doc_ids, - data_cache_path=self.data_cache_path) + self.dataset = _build_train_valid_test_datasets_single( + self.prefix, + self.data_impl, + self.splits_string, + self.num_samples_train_valid_test, + self.seq_length, + self.seed, + self.skip_warmup, + self.name, + self.return_doc_ids, + data_cache_path=self.data_cache_path, + ) self.build = True return self.dataset @@ -98,21 +137,27 @@ def __init__(self, dataset_builders, shuffle=False): self.dataset_builders = dataset_builders self.num_datasets = len(dataset_builders) self.num_samples = np.sum([d.num_samples for d in dataset_builders]) - self.indices=np.zeros((self.num_samples, 2), dtype=np.uint64) - self.desc="ConcatDataset:" - m = 0 + self.indices = np.zeros((self.num_samples, 2), dtype=np.uint64) + self.desc = "ConcatDataset:" + # m = 0 num_samples_list = np.array([d.num_samples for d in dataset_builders]) self.num_samples = np.sum(num_samples_list) + def _build_indices(): start_time = time.time() dataset_index = np.zeros(self.num_samples, dtype=np.int64) dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64) - helpers.build_concat_indices(dataset_index, dataset_sample_index, - num_samples_list, - self.num_datasets, - torch.distributed.get_rank()==0) - print_rank_0('> elapsed time for building concat dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + helpers.build_concat_indices( + 
dataset_index, + dataset_sample_index, + num_samples_list, + self.num_datasets, + torch.distributed.get_rank() == 0, + ) + log.debug( + "> elapsed time for building concat dataset indices: " + "{:.2f} (sec)".format(time.time() - start_time) + ) return dataset_index, dataset_sample_index self.dataset_index, self.dataset_sample_index = _build_indices() @@ -123,7 +168,12 @@ def _build_indices(): for i in range(self.num_datasets): self.desc += dataset_builders[i].prefix + "," - self.desc += f"-{self.num_samples}" + f"-{dataset_builders[0].seq_length}" + f"{dataset_builders[0].seed}" + self.desc += ( + f"-{self.num_samples}" + + f"-{dataset_builders[0].seq_length}" + + f"{dataset_builders[0].seed}" + ) + def __len__(self): return self.num_samples @@ -136,227 +186,340 @@ def __getitem__(self, idx): return self.dataset_builders[i].dataset[j] else: return self.dataset_builders[i].Build()[j] - - # Predetermine whether need to build the specific dataset or not. + # Predetermine whether need to build the specific dataset or not. start_time = time.time() - print_rank_0(" >>> Started building datasets in distributed way ... ") + log.debug(" >>> Started building datasets in distributed way ... ") a, b, c = [int(d) for d in splits_string.split(",")] - + train_datasets = [] valid_datasets = [] test_datasets = [] # Build individual datasets. args = get_args() @dlp.log - def build_corpus_datasets(dataset_type='train'): + def build_corpus_datasets(dataset_type="train"): start_time = time.time() - print_rank_0(f" >>> Building {dataset_type} corpus datasets ...") + log.debug(f" >>> Building {dataset_type} corpus datasets ...") datasets = [] corpus_builders = {} corpus_weights = {} for c in corpus_list: corpus_builders[c] = [] corpus_weights[c] = 0.0 - dataset_builders = [DatasetBuilder(prefixes[i], corpuses[i], data_impl, splits_string, - datasets_train_valid_test_num_samples[i], - seq_length, seed, skip_warmup, - return_doc_ids,data_cache_path, dataset_type) for i in range(len(weights))] - for i in range(torch.distributed.get_rank()//mpu.get_tensor_model_parallel_world_size(), len(weights), torch.distributed.get_world_size()//mpu.get_tensor_model_parallel_world_size()): + dataset_builders = [ + DatasetBuilder( + prefixes[i], + corpuses[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + return_doc_ids, + data_cache_path, + dataset_type, + ) + for i in range(len(weights)) + ] + for i in range( + torch.distributed.get_rank() + // mpu.get_tensor_model_parallel_world_size(), + len(weights), + torch.distributed.get_world_size() + // mpu.get_tensor_model_parallel_world_size(), + ): dataset_builders[i].Build() - print_rank_0(f" >>> Finished building individual datasets in {time.time() - start_time} seconds") + log.debug( + f" >>> Finished building individual datasets in {time.time() - start_time} seconds" + ) start_concating_time = time.time() for i, d in zip(range(len(weights)), dataset_builders): corpus_builders[d.corpus].append(d) corpus_weights[d.corpus] += weights[i] total = 0 - print_rank_0(" > number of samples for each corpus ") - corpus_weights_achieved={} + log.debug(" > number of samples for each corpus ") + corpus_weights_achieved = {} for c in corpus_list: datasets.append(BuildConcatDataset(corpus_builders[c], args.shuffle_sample)) total += datasets[-1].num_samples - corpus_weights_achieved[c] = float(datasets[-1].num_samples)/train_num_samples - print_rank_0(f" {c}: {datasets[-1].num_samples} w={corpus_weights_achieved[c]} (expected: 
{corpus_weights[c]})") - - print_rank_0(f" > total number of samples: {total}") - print_rank_0(f" >>> Finished concatenating datasets in {time.time() - start_concating_time} seconds") - print_rank_0(f" >>> Finished building {dataset_type} corpus datasets in {time.time() - start_time} seconds") + corpus_weights_achieved[c] = ( + float(datasets[-1].num_samples) / train_num_samples + ) + log.debug( + f" {c}: {datasets[-1].num_samples} w={corpus_weights_achieved[c]} (expected: {corpus_weights[c]})" + ) + + log.debug(f" > total number of samples: {total}") + log.debug( + f" >>> Finished concatenating datasets in {time.time() - start_concating_time} seconds" + ) + log.debug( + f" >>> Finished building {dataset_type} corpus datasets in {time.time() - start_time} seconds" + ) return datasets, [corpus_weights_achieved[c] for c in corpus_list] + train_weights = None if a > 0: - train_datasets, train_weights = build_corpus_datasets('train') - + train_datasets, train_weights = build_corpus_datasets("train") + valid_weights = None if b > 0: - valid_datasets, valid_weights = build_corpus_datasets('valid') - - if c > 0: - test_datasets, test_weights = build_corpus_datasets('test') + valid_datasets, valid_weights = build_corpus_datasets("valid") + test_weights = None + if c > 0: + test_datasets, test_weights = build_corpus_datasets("test") # This barrier is critical to make sure that all the datasets are built once # and the metadata were written to the cache folder before other ranks touch them - print_rank_0(f" >>> Rank 0 - finished building datasets in {time.time() - start_time} seconds") + log.debug( + f" >>> Rank 0 - finished building datasets in {time.time() - start_time} seconds" + ) torch.distributed.barrier(group=mpu.get_data_parallel_group()) torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) torch.distributed.barrier(group=mpu.get_data_parallel_group()) - print_rank_0(f" >>> Finished building datasets (all ranks) in distributed way in {time.time() - start_time} seconds") - print_rank_0(f" >>> Starting to build BlendableDataset") + log.debug( + f" >>> Finished building datasets (all ranks) in distributed way in {time.time() - start_time} seconds" + ) + log.debug(" >>> Starting to build BlendableDataset") # Blend. 
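BlendableDataset (built just below) interleaves the per-corpus datasets so that, over the run, each one is drawn roughly in proportion to its achieved weight. The index pair it precomputes comes from the compiled helpers.build_blending_indices; the following pure-NumPy loop is only an approximation of that greedy strategy, for intuition:

import numpy as np

def blending_indices(weights, size):
    # For every global sample, pick the dataset that is currently most
    # under-represented relative to its normalized weight.
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    counts = np.zeros(len(weights), dtype=np.int64)
    dataset_index = np.zeros(size, dtype=np.int64)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    for i in range(size):
        errors = weights * (i + 1) - counts      # how far each dataset lags
        d = int(np.argmax(errors))
        dataset_index[i] = d                     # which dataset to read from
        dataset_sample_index[i] = counts[d]      # which sample within it
        counts[d] += 1
    return dataset_index, dataset_sample_index

idx, sample_idx = blending_indices([0.7, 0.2, 0.1], size=10)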
start_time = time.time() blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, train_weights, train_num_samples, - data_cache_path=data_cache_path) + if train_datasets and train_weights: + blending_train_dataset = BlendableDataset( + train_datasets, + train_weights, + train_num_samples, + data_cache_path=data_cache_path, + ) blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, valid_weights, valid_num_samples, - data_cache_path=data_cache_path) + if valid_datasets and valid_weights: + blending_valid_dataset = BlendableDataset( + valid_datasets, + valid_weights, + valid_num_samples, + data_cache_path=data_cache_path, + ) blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, test_weights, test_num_samples, - data_cache_path=data_cache_path) + if test_datasets and test_weights: + blending_test_dataset = BlendableDataset( + test_datasets, + test_weights, + test_num_samples, + data_cache_path=data_cache_path, + ) end_time = time.time() - print_rank_0(f" >>> Finished building BlendableDataset in {end_time - start_time} seconds") - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + log.debug( + f" >>> Finished building BlendableDataset in {end_time - start_time} seconds" + ) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) else: - print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.") + log.debug( + "Separate data paths provided for train, valid & test. Split string will be ignored." + ) train_dataset, valid_dataset, test_dataset = None, None, None # Single dataset. if train_data_prefix is not None: - train_dataset = build_dataset("train", train_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[0], - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + train_dataset = build_dataset( + "train", + train_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[0], + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) if valid_data_prefix is not None: - valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[1], - seq_length, seed, False, - data_cache_path=data_cache_path) - + valid_dataset = build_dataset( + "valid", + valid_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[1], + seq_length, + seed, + False, + data_cache_path=data_cache_path, + ) if test_data_prefix is not None: - test_dataset = build_dataset("test", test_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[2], - seq_length, seed, False, - data_cache_path=data_cache_path) + test_dataset = build_dataset( + "test", + test_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[2], + seq_length, + seed, + False, + data_cache_path=data_cache_path, + ) return (train_dataset, valid_dataset, test_dataset) + @dlp.log -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - return_doc_ids=False, *, - data_cache_path=None): +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" # Indexed 
dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. - print_rank_0(' > dataset split:') + log.debug(" > dataset split:") def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + log.debug(" {}:".format(name)) + log.debug( + " document indices in [{}, {}) total of {} " "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, - splits_string, - train_valid_test_num_samples[index], - seq_length, seed, - return_doc_ids, - data_cache_path=data_cache_path) + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, + seed, + return_doc_ids, + data_cache_path=data_cache_path, + ) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) + @dlp.log -def _build_train_valid_test_datasets_single(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, name, - return_doc_ids=False, *, - data_cache_path=None): +def _build_train_valid_test_datasets_single( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + name, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" # Each rank print out information - print_rank_0(f" >> building dataset for {data_prefix}") + log.debug(f" >> building dataset for {data_prefix}") # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. 
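get_train_valid_test_split_ (imported from megatron/data/dataset_utils.py) turns the comma-separated splits string into document-index boundaries over [0, total_num_of_documents). A small sketch of that conversion, assuming simple proportional rounding (the real helper handles a few more edge cases):

def split_boundaries(splits_string, total_docs):
    # e.g. "949,50,1" -> [0, train_end, valid_end, test_end]
    parts = [float(s) for s in splits_string.split(",")]
    parts += [0.0] * (3 - len(parts))            # pad to train/valid/test
    total = sum(parts)
    bounds = [0]
    for p in parts:
        bounds.append(bounds[-1] + int(round(p / total * total_docs)))
    bounds[-1] = total_docs                      # absorb rounding error
    return bounds

print(split_boundaries("949,50,1", 10_000))      # [0, 9490, 9990, 10000]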
- print_rank_0(' > dataset split:') + log.debug(" > dataset split:") def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + log.debug(" {}:".format(name)) + log.debug( + " document indices in [{}, {}) total of {} " "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, - splits_string, - train_valid_test_num_samples[index], - seq_length, seed, - return_doc_ids, - data_cache_path=data_cache_path) + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, + seed, + return_doc_ids, + data_cache_path=data_cache_path, + ) return dataset - if name.find("train")!=-1: - return build_dataset(0, 'train') - if name.find("valid")!=-1: - return build_dataset(1, 'valid') - if name.find("test")!=-1: - return build_dataset(2, 'test') + + if name.find("train") != -1: + return build_dataset(0, "train") + if name.find("valid") != -1: + return build_dataset(1, "valid") + if name.find("test") != -1: + return build_dataset(2, "test") + @dlp.log -def build_dataset(dataset_name, data_prefix, data_impl, - splits_string, num_samples, - seq_length, seed, skip_warmup, - *, - data_cache_path=None): +def build_dataset( + dataset_name, + data_prefix, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + *, + data_cache_path=None, +): dataset = None if len(data_prefix) == 1: - dataset = _build_dataset(dataset_name, data_prefix[0], data_impl, - splits_string, num_samples, seq_length, - seed, skip_warmup, - data_cache_path=data_cache_path) + dataset = _build_dataset( + dataset_name, + data_prefix[0], + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) else: # Blending dataset. # Parse the values. @@ -367,73 +530,108 @@ def build_dataset(dataset_name, data_prefix, data_impl, # Build individual datasets. 
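For the blended multi-prefix path here, the weights-and-num-samples helper imported earlier first normalizes the raw weights and turns the requested sample count into a per-prefix budget. A rough sketch of that allocation; the real helper also pads each budget by a small safety margin, which is omitted:

import math

def weights_and_num_samples(weights, num_samples):
    total = sum(weights)
    norm = [w / total for w in weights]
    # Each prefix gets at least its proportional share, rounded up.
    per_dataset = [int(math.ceil(w * num_samples)) for w in norm]
    return norm, per_dataset

norm, per_dataset = weights_and_num_samples([3.0, 1.0], num_samples=1000)
# norm == [0.75, 0.25], per_dataset == [750, 250]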
datasets = [] for i in range(len(prefixes)): - ds = _build_dataset(dataset_name, prefixes[i], data_impl, - splits_string, dataset_num_samples[i], - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + ds = _build_dataset( + dataset_name, + prefixes[i], + data_impl, + splits_string, + dataset_num_samples[i], + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) if ds: datasets.append(ds) if datasets: - dataset = BlendableDataset(datasets, weights, num_samples, - data_cache_path=data_cache_path) + dataset = BlendableDataset( + datasets, weights, num_samples, data_cache_path=data_cache_path + ) return dataset + @dlp.log -def _build_dataset(dataset_name, data_prefix, data_impl, splits_string, - num_samples, seq_length, seed, skip_warmup, - *, - data_cache_path=None): +def _build_dataset( + dataset_name, + data_prefix, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + *, + data_cache_path=None, +): """ Build dataset. This method is called when individual train, valid, test datasets are provided """ # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] - print_rank_0(' {}:'.format(dataset_name)) - print_rank_0(' document indices in [0, {}) total of {} ' - 'documents'.format(total_num_of_documents, total_num_of_documents)) - - documents = np.arange(start=0, stop=total_num_of_documents, - step=1, dtype=np.int32) - - dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset, - splits_string, num_samples, seq_length, seed, - data_cache_path=data_cache_path) + log.debug(" {}:".format(dataset_name)) + log.debug( + " document indices in [0, {}) total of {} " "documents".format( + total_num_of_documents, total_num_of_documents + ) + ) + + documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) + + dataset = GPTDataset( + dataset_name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + data_cache_path=data_cache_path, + ) return dataset + @dlp.log def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): """Build indexed dataset.""" - print_rank_0(' > building dataset index ...') + log.debug(" > building dataset index ...") start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) - print_rank_0(' > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time)) - print_rank_0(' number of documents: {}'.format( - indexed_dataset.sizes.shape[0])) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + log.debug( + " > finished creating indexed dataset in {:4f} " "seconds".format( + time.time() - start_time + ) + ) + log.debug(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) return indexed_dataset class GPTDataset(torch.utils.data.Dataset): @dlp.log - def __init__(self, name, data_prefix, documents, indexed_dataset, - splits_string, num_samples, seq_length, seed, - return_doc_ids=False, *, - data_cache_path=None): - + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + return_doc_ids=False, + *, + data_cache_path=None, + ): self.name = name self.indexed_dataset = indexed_dataset self.return_doc_ids = return_doc_ids @@ -443,20 +641,29 @@ def __init__(self, name, 
data_prefix, documents, indexed_dataset, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ - _build_index_mappings(self.name, data_prefix, - documents, self.indexed_dataset.sizes, - splits_string, num_samples, seq_length, seed, - data_cache_path=data_cache_path) - + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = ( + _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + splits_string, + num_samples, + seq_length, + seed, + data_cache_path=data_cache_path, + ) + ) def __len__(self): # -1 is due to data structure used to retieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) return self.sample_idx.shape[0] - 1 + @dlp.log def __getitem__(self, idx): args = get_args() + assert args is not None orig_idx = idx # Get the shuffled index. try: @@ -465,21 +672,24 @@ def __getitem__(self, idx): if is_rank_0(): import json from rich import print_json + print(exc) print( - '\n'.join( - ['-------------------------------------------------', - f'Trying to access {idx=} from self.shuffle_idx,', - f'but {len(self.shuffle_idx)=}', - '-------------------------------------------------'] + "\n".join( + [ + "-------------------------------------------------", + f"Trying to access {idx=} from self.shuffle_idx,", + f"but {len(self.shuffle_idx)=}", + "-------------------------------------------------", + ] ) ) print_json( json.dumps( { - 'doc_idx': len(self.doc_idx), - 'sample_idx': len(self.sample_idx), - 'shuffle_idx': len(self.shuffle_idx), + "doc_idx": len(self.doc_idx), + "sample_idx": len(self.sample_idx), + "shuffle_idx": len(self.shuffle_idx), }, indent=4, ) @@ -493,45 +703,57 @@ def __getitem__(self, idx): doc_ids = [] if doc_index_f == doc_index_l: doc_ids.append(self.doc_idx[doc_index_f]) - sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1) + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) else: # Otherwise, get the rest of the initial document. doc_ids.append(self.doc_idx[doc_index_f]) - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f)] + sample_list = [ + self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): doc_ids.append(self.doc_idx[i]) sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) # And finally add the relevant portion of last document. 
doc_ids.append(self.doc_idx[doc_index_l]) - sample_list.append(self.indexed_dataset.get( - self.doc_idx[doc_index_l], - length=offset_l + 1)) + sample_list.append( + self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) sample = np.concatenate(sample_list) - text_name = 'text' + text_name = "text" if args.use_dataset_only: - text_name = 'input_ids' + text_name = "input_ids" sample_dict = {text_name: np.array(sample, dtype=np.int64)} if args.return_data_index: - sample_dict.update({'index': np.array([orig_idx], dtype=np.int64)}) + sample_dict.update({"index": np.array([orig_idx], dtype=np.int64)}) - if self.return_doc_ids: # for retro preprocessing - sample_dict.update({'doc_ids': np.array(doc_ids, dtype=np.int64)}) + if self.return_doc_ids: # for retro preprocessing + sample_dict.update({"doc_ids": np.array(doc_ids, dtype=np.int64)}) if args.use_dataset_only: - sample_dict.update({'labels': np.array(sample, dtype=np.int64)}) + sample_dict.update({"labels": np.array(sample, dtype=np.int64)}) return sample_dict + @dlp.log -def _build_index_mappings(name, data_prefix, documents, sizes, - splits_string, num_samples, seq_length, seed, - *, - data_cache_path): +def _build_index_mappings( + name, + data_prefix, + documents, + sizes, + splits_string, + num_samples, + seq_length, + seed, + *, + data_cache_path, +): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. sample-idx: is the start document index and document offset for each @@ -539,10 +761,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes, shuffle-idx: maps the sample index into a random index into sample-idx. """ args = get_args() + assert args is not None # Number of tokens in each epoch and number of required epochs. 
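The docstring above defines the three mappings; a tiny worked example of how __getitem__ consumes them (all values invented, the real arrays come from the cached .npy files):

import numpy as np

doc_idx = np.array([2, 0, 1], dtype=np.int32)     # epoch-shuffled document order
sample_idx = np.array([[0, 0],                    # sample i starts at (row into doc_idx, token offset)
                       [0, 4],
                       [1, 2],
                       [2, 1]], dtype=np.int64)
shuffle_idx = np.array([2, 0, 1], dtype=np.int64) # randomizes the sample order

idx = 1
i = shuffle_idx[idx]                              # -> 0
doc_f, off_f = sample_idx[i]                      # start: doc_idx[0] == document 2, token 0
doc_l, off_l = sample_idx[i + 1]                  # end:   doc_idx[0] == document 2, token 4
print(doc_idx[doc_f], off_f, doc_idx[doc_l], off_l)   # 2 0 2 4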
tokens_per_epoch = _num_tokens(documents, sizes) num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) - if args.train_data_exact_num_epochs is not None and name == 'train': + if args.train_data_exact_num_epochs is not None and name == "train": num_epochs = args.train_data_exact_num_epochs # rng state @@ -557,13 +780,13 @@ def _build_index_mappings(name, data_prefix, documents, sizes, desc += f"Sequence length {seq_length}\n" desc += f"Random seed {seed}\n" desc += f"Split {splits_string}\n" - desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest() desc_filename = desc_hash + ".dsc" - doc_idx_filename = desc_hash + '_doc_idx.npy' - sample_idx_filename = desc_hash + '_sample_idx.npy' - shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + doc_idx_filename = desc_hash + "_doc_idx.npy" + sample_idx_filename = desc_hash + "_sample_idx.npy" + shuffle_idx_filename = desc_hash + "_shuffle_idx.npy" - if name == 'train': + if name == "train": # force to use certain index files if args.train_desc_path is not None: desc_filename = args.train_desc_path @@ -578,15 +801,15 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # duplication, then look in data-cache-path if specified, # If nothing is found, use the last path looked in build_indices = True - prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + prefixes = [os.path.join(os.path.dirname(data_prefix), "index-cache")] if data_cache_path is not None: prefixes.append(data_cache_path) for prefix in prefixes: idx_path = { - 'desc': os.path.join(prefix, desc_filename), - 'doc': os.path.join(prefix, doc_idx_filename), - 'sample': os.path.join(prefix, sample_idx_filename), - 'shuffle': os.path.join(prefix, shuffle_idx_filename) + "desc": os.path.join(prefix, desc_filename), + "doc": os.path.join(prefix, doc_idx_filename), + "sample": os.path.join(prefix, sample_idx_filename), + "shuffle": os.path.join(prefix, shuffle_idx_filename), } for f in idx_path.values(): if not os.path.isfile(f): @@ -595,15 +818,17 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Found our files! build_indices = False break - data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_dir = os.path.dirname(idx_path["desc"]) data_cache_success = True # Build the indexed mapping if not exist. if build_indices: - # Since this function will be called by all the rank in the very beginning. Therefore, we assume that all the - # ranks will first create the document files, and then read it. + # Since this function will be called by all the rank in the very beginning. Therefore, we assume that all the + # ranks will first create the document files, and then read it. # There will not be contension effects going on either - print_rank_0(f" > WARNING: could not find index map files, building on rank {torch.distributed.get_rank()}") + log.warning( + f" > WARNING: could not find index map files, building on rank {torch.distributed.get_rank()}" + ) # For the last epoch, decide whether include the entire epoch # in the global shuffle or not. @@ -612,64 +837,80 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # not mean anything. 
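The cache lookup above keys every index file off an md5 of the dataset description, so any change in prefix, sample count, sequence length, seed or split forces a rebuild. A quick sketch of the naming scheme (the description string here is shortened; the real one carries more fields):

import hashlib

# Shortened, hypothetical description; the real string also records the
# dataset name, number of samples and data prefix.
desc = "Sequence length 4096\nRandom seed 1234\nSplit 100,0,0\n"
desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest()

files = {
    "desc": desc_hash + ".dsc",
    "doc": desc_hash + "_doc_idx.npy",
    "sample": desc_hash + "_sample_idx.npy",
    "shuffle": desc_hash + "_shuffle_idx.npy",
}
print(files["doc"])  # <32-char md5 hex digest>_doc_idx.npy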
if num_epochs == 1: separate_last_epoch = False - print_rank_0(' > only one epoch required, setting ' - 'separate_last_epoch to False') + log.debug( + " > only one epoch required, setting " "separate_last_epoch to False" + ) else: # Get the number of samples for the last epoch num_samples_from_epochs_minus_one = ( - (num_epochs - 1) * tokens_per_epoch - 1) // seq_length - last_epoch_num_samples = num_samples - \ - num_samples_from_epochs_minus_one - assert last_epoch_num_samples >= 0, \ - 'last epoch number of samples should be non-negative.' + (num_epochs - 1) * tokens_per_epoch - 1 + ) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert ( + last_epoch_num_samples >= 0 + ), "last epoch number of samples should be non-negative." num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length - assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ - 'last epoch number of samples exceeded max value.' + assert last_epoch_num_samples <= ( + num_samples_per_epoch + 1 + ), "last epoch number of samples exceeded max value." # If we have less than 80% of the samples for the last epoch, # seperate out the epoch and treat it differently. # Note: the 80% number is just based on common sense and can # be adjusted if needed. - separate_last_epoch = (last_epoch_num_samples < - int(0.80 * num_samples_per_epoch)) + separate_last_epoch = last_epoch_num_samples < int( + 0.80 * num_samples_per_epoch + ) if separate_last_epoch: - string = ' > last epoch number of samples ({}) is smaller '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to True' + string = ( + " > last epoch number of samples ({}) is smaller " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to True" + ) else: - string = ' > last epoch number of samples ({}) is larger '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to False' - print_rank_0(string.format(last_epoch_num_samples, - num_samples_per_epoch)) - + string = ( + " > last epoch number of samples ({}) is larger " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to False" + ) + log.debug(string.format(last_epoch_num_samples, num_samples_per_epoch)) try: os.makedirs(data_cache_dir, exist_ok=True) # description - with open(idx_path['desc'], 'wt') as fd: + with open(idx_path["desc"], "wt") as fd: fd.write(desc) # doc-idx. start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng, - separate_last_epoch) - np.save(idx_path['doc'], doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(idx_path["doc"], doc_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # sample-idx. start_time = time.time() # Use C++ implementation for speed. # First compile and then import. 
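Worked numbers for the separate_last_epoch rule above (figures invented for illustration):

tokens_per_epoch = 1_000_000
seq_length = 2048
num_samples = 1_400          # requested training samples
num_epochs = 3               # smallest count whose tokens cover num_samples

num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length                          # 488
samples_from_first_epochs = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length   # 976
last_epoch_num_samples = num_samples - samples_from_first_epochs                      # 424

# 424 >= int(0.80 * 488) == 390, so the last epoch is large enough and
# separate_last_epoch stays False: one global shuffle over all epochs.
separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)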
from megatron.data import helpers + assert doc_idx.dtype == np.int32 assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch, torch.distributed.get_rank()==0) - np.save(idx_path['sample'], sample_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + sample_idx = helpers.build_sample_idx( + sizes, + doc_idx, + seq_length, + num_epochs, + tokens_per_epoch, + torch.distributed.get_rank() == 0, + ) + np.save(idx_path["sample"], sample_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # shuffle-idx. start_time = time.time() # -1 is due to data structure used to retieve the index: @@ -678,35 +919,46 @@ def _build_index_mappings(name, data_prefix, documents, sizes, num_samples_ = num_samples_from_epochs_minus_one else: num_samples_ = sample_idx.shape[0] - 1 - shuffle_idx = _build_shuffle_idx(num_samples_, - sample_idx.shape[0] - 1, np_rng) - np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) + shuffle_idx = _build_shuffle_idx( + num_samples_, sample_idx.shape[0] - 1, np_rng + ) + np.save(idx_path["shuffle"], shuffle_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) except OSError: - print(f'There was an error trying to create the data cache directory ({data_cache_dir})') - print('or a file in it. This defaults to a directory "index-cache" within the directory') - print('the data files are in and can be set with the --data-cache-path argument. Please') - print('ensure you have write access to this directory or specify one that you do have') - print('write access to.') + print( + f"There was an error trying to create the data cache directory ({data_cache_dir})" + ) + print( + 'or a file in it. This defaults to a directory "index-cache" within the directory' + ) + print( + "the data files are in and can be set with the --data-cache-path argument. Please" + ) + print( + "ensure you have write access to this directory or specify one that you do have" + ) + print("write access to.") data_cache_success = False # Load mappings. 
start_time = time.time() - print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") - doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path["doc"], allow_pickle=True, mmap_mode="r") - print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") - sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path["sample"], allow_pickle=True, mmap_mode="r") - print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") - shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path["shuffle"], allow_pickle=True, mmap_mode="r") - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - sample_idx.shape[0])) - print_rank_0(' total number of epochs: {}'.format(num_epochs)) + log.debug( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + log.debug(" total number of samples: {}".format(sample_idx.shape[0])) + log.debug(" total number of epochs: {}".format(num_epochs)) return doc_idx, sample_idx, shuffle_idx, desc, desc_hash @@ -730,25 +982,26 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples): if ((total_tokens - 1) // seq_length) >= num_samples: return num_epochs + @dlp.log def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): """Build an array with length = number-of-epochs * number-of-dcuments. Each index is mapped to a corresponding document.""" if not separate_last_epoch or num_epochs == 1: - doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] doc_idx[:] = documents doc_idx = doc_idx.reshape(-1) doc_idx = doc_idx.astype(np.int32) np_rng.shuffle(doc_idx) return doc_idx - doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) return np.concatenate((doc_idx_first, doc_idx_last)) + @dlp.log -def _build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch): +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): """Sample index mapping is a 2D array with sizes [number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is the @@ -782,7 +1035,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length, # Note that -1 here is for the same reason we have -1 in # `_num_epochs` calculations. if remaining_seq_length <= 0: - doc_offset += (remaining_seq_length + doc_length - 1) + doc_offset += remaining_seq_length + doc_length - 1 remaining_seq_length = 0 else: # Otherwise, start from the begining of the next document. 
@@ -795,24 +1048,28 @@ def _build_sample_idx(sizes, doc_idx, seq_length, return sample_idx + @dlp.log def _build_shuffle_idx(num_samples, total_size, np_rng): """Build the range [0, size) and shuffle.""" - print_rank_0(' > building shuffle index with split [0, {}) and [{}, {}) ' - '...'.format(num_samples, num_samples, total_size)) + log.debug( + " > building shuffle index with split [0, {}) and [{}, {}) " "...".format( + num_samples, num_samples, total_size + ) + ) dtype_ = np.uint32 if total_size >= (np.iinfo(np.uint32).max - 1): dtype_ = np.int64 - shuffle_idx_first = np.arange(start=0, stop=num_samples, - step=1, dtype=dtype_) + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) np_rng.shuffle(shuffle_idx_first) if num_samples == total_size: return shuffle_idx_first - shuffle_idx_last = np.arange(start=num_samples, stop=total_size, - step=1, dtype=dtype_) + shuffle_idx_last = np.arange( + start=num_samples, stop=total_size, step=1, dtype=dtype_ + ) np_rng.shuffle(shuffle_idx_last) return np.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 1eb9b7842b..8479f33fab 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -15,17 +15,24 @@ from functools import lru_cache import os + +# import logging import shutil import struct from itertools import accumulate import numpy as np import torch -from megatron import print_rank_0 -from megatron.utils import Profile + +# from megatron import print_rank_0 +from megatron.utils import Profile, get_logger + +log = get_logger(__name__) + dlp = Profile("DATASET") + def __best_fitting_dtype(vocab_size=None): if vocab_size is not None and vocab_size < 65500: return np.uint16 @@ -34,28 +41,32 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['lazy', 'cached', 'mmap'] + return ["lazy", "cached", "mmap"] def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) return None def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': - return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) else: return IndexedDatasetBuilder(out_file) @@ -63,22 +74,24 @@ def make_builder(out_file, impl, vocab_size=None): def make_dataset(path, impl, skip_warmup=False): if not IndexedDataset.exists(path): print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." 
+ ) return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -114,11 +127,11 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -131,38 +144,41 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() self.path = path self.data_file = None self.read_index(path) + @dlp.log def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -176,7 +192,7 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -185,7 +201,7 @@ def __getitem__(self, idx): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -205,8 +221,8 @@ def size(self, index): @staticmethod def exists(path): - return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) ) @property @@ -215,7 +231,6 @@ def supports_prefetch(self): class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -224,6 +239,7 @@ def __init__(self, path): @property def supports_prefetch(self): return True + @dlp.log def prefetch(self, indices): if all(i in self.cache_index for i in indices): @@ -240,7 +256,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx: ptx + size] + a = self.cache[ptx : ptx + size] 
self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -255,10 +271,10 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx: ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -278,15 +294,17 @@ class IndexedDatasetBuilder(object): np.float32: 4, np.float64: 8, } + @dlp.log def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] self.sizes = [] self.element_size = self.element_sizes[self.dtype] self.doc_idx = [0] + @dlp.log def add_item(self, tensor): bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) @@ -297,6 +315,7 @@ def add_item(self, tensor): def end_document(self): self.doc_idx.append(len(self.sizes)) + @dlp.log def merge_file_(self, another_file): index = IndexedDataset(another_file) @@ -315,7 +334,7 @@ def merge_file_(self, another_file): self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -325,21 +344,22 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack(' 0 + if args.enable_zbh1_pipeline: + deepspeed.runtime.pipe.schedule.TrainSchedule = ZeroBubbleH1Pipeline + deepspeed.runtime.pipe.engine.PipelineEngine._INSTRUCTION_MAP.update( + { + BackwardOnlyPass: _exec_backward_only_pass, + WeightPass: _exec_weight_pass, + } + ) # Call the init process if args.deepspeed or args.ds_inference: deepspeed.init_distributed() diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 141c901ffa..f2beea06e8 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -1,22 +1,36 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
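When --enable-zbh1-pipeline is set, the initialization above swaps DeepSpeed's TrainSchedule for ZeroBubbleH1Pipeline and registers executors for the two new instructions. Reduced to phase labels (F = forward, B = backward w.r.t. inputs, W = weight-gradient pass; the first stage fuses B and W into a full backward), the steps() generator lays out roughly the following per-stage pattern. This standalone reduction is only for intuition and drops buffers and communication:

def zbh1_phases(stage_id, stages, micro_batches):
    warmup = stages - stage_id
    fwd = bwd = wgt = 0
    out = []
    for _ in range(warmup - 1):            # warmup forwards
        if fwd < micro_batches:
            out.append("F")
            fwd += 1
    for _ in range(stage_id):              # forward + backward, no weight work yet
        if fwd < micro_batches:
            out.append("FB")
            fwd += 1
            bwd += 1
    while fwd < micro_batches:             # steady state: F, B and one deferred W
        out.append("FBW")
        fwd += 1
        bwd += 1
        wgt += 1
    while bwd < micro_batches:             # drain remaining backwards
        out.append("BW")
        bwd += 1
        wgt += 1
    while wgt < micro_batches:             # flush remaining weight passes
        out.append("W")
        wgt += 1
    return out

for stage in range(4):                     # e.g. 4 stages, 8 micro-batches
    print(stage, zbh1_phases(stage, stages=4, micro_batches=8))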
+# type:ignore +# noqa: E401,E402,F401 -# from deepspeed.accelerator.real_accelerator import get_accelerator -# if get_accelerator().device_name() == 'cuda': -try: +import torch +from deepspeed.accelerator.real_accelerator import get_accelerator + +accelerator = get_accelerator() + +if accelerator is not None and accelerator.device_name() == "xpu": + import intel_extension_for_pytorch # noqa: F401 # type: ignore + +if accelerator is not None and accelerator.device_name() == "cuda": from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm - from apex.normalization import MixedFusedRMSNorm as RMSNorm - HAS_APEX = True -except Exception: - HAS_APEX = False - from .rmsnorm import RMSNorm - from torch.nn import LayerNorm -# else: -# from .rmsnorm import RMSNorm -# from torch.nn import LayerNorm - -from .distributed import DistributedDataParallel -from .bert_model import BertModel -from .gpt_model import GPTModel, GPTModelPipe -from .t5_model import T5Model -from .language_model import get_language_model -from .module import Float16Module + + try: + from apex.normalization import MixedFusedRMSNorm as RMSNorm # type:ignore + + HAS_APEX = True + except Exception: + HAS_APEX = False + from .rmsnorm import RMSNorm +else: + if hasattr(torch.xpu, "IpexRmsNorm"): + from .fused_rmsnorm import RMSNorm + else: + from .rmsnorm import RMSNorm # noqa:E401,E402,F401 + from torch.nn import LayerNorm # noqa:E401,E402,F401 + + +from .distributed import DistributedDataParallel # noqa:E401,E402,F401 +from .bert_model import BertModel # noqa:E401,E402,F401 +from .gpt_model import GPTModel, GPTModelPipe # noqa:E401,E402,F401 +from .t5_model import T5Model # noqa:E401,E402,F401 +from .language_model import get_language_model # noqa:E401,E402,F401 +from .module import Float16Module # noqa:E401,E402,F401 diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 2f3b89014b..d1ef034397 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -1,9 +1,11 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """This code is copied fron NVIDIA apex: https://github.com/NVIDIA/apex with some changes. 
""" +from deepspeed.accelerator.real_accelerator import get_accelerator import numbers import torch from torch.nn.parameter import Parameter @@ -13,6 +15,7 @@ import inspect from megatron.core.utils import make_viewless_tensor +from megatron import get_args try: from apex.contrib.layer_norm.layer_norm import FastLayerNormFN @@ -56,8 +59,15 @@ def __init__(self, normalized_shape, eps=1e-5, normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) + init_device = None + if get_accelerator().device_name() == 'hpu': + init_device = get_accelerator().current_device_name() + self.weight = Parameter(torch.empty(*normalized_shape, + device=init_device, + dtype=get_args().params_dtype)) + self.bias = Parameter(torch.empty(*normalized_shape, + device=init_device, + dtype=get_args().params_dtype)) self.reset_parameters() self.no_persist_layer_norm = no_persist_layer_norm self.sequence_parallel = sequence_parallel diff --git a/megatron/model/fused_rmsnorm.py b/megatron/model/fused_rmsnorm.py new file mode 100644 index 0000000000..d69b0822a4 --- /dev/null +++ b/megatron/model/fused_rmsnorm.py @@ -0,0 +1,24 @@ +from megatron import get_args + +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +import intel_extension_for_pytorch as ipex # noqa + + +# Taken from facebookresearch/llama +class RMSNorm(torch.nn.Module): + def __init__( + self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False + ): + super().__init__() + self.eps = eps + self.weight = Parameter(torch.ones(dim, dtype=get_args().params_dtype)) + self.sequence_parallel = sequence_parallel + setattr(self.weight, "sequence_parallel", self.sequence_parallel) + + def forward(self, x): + output = torch.xpu.IpexRmsNorm( + x, self.weight.shape, self.weight, self.eps + ) + return output diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 8968c96655..e5e60c43ee 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """GPT-2 model.""" @@ -393,9 +394,12 @@ def _to_float16(inputs): if args.normalization == 'layernorm': self.specs.append(LayerSpec(LayerNorm, args.hidden_size, - eps=args.layernorm_epsilon)) + eps=args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) else: - self.specs.append(LayerSpec(RMSNorm, args.hidden_size, args.layernorm_epsilon)) + self.specs.append(LayerSpec(RMSNorm, args.hidden_size, + args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) def _logits_helper(embedding, lm_output): """A wrapper to massage inputs/outputs from pipeline. """ diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index ceef5be725..eebf8744ca 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Transformer based language model.""" @@ -256,8 +257,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # Dropout. 
if self.sequence_parallel: - # already partition sequence, do not need scatter_to_sequence_parallel_region - # embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + # already partition sequence, do not need scatter_to_sequence_parallel_region ? + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) with tensor_parallel.get_cuda_rng_tracker().fork(): embeddings = self.embedding_dropout(embeddings) else: diff --git a/megatron/model/module.py b/megatron/model/module.py index 963ad2d29d..08cf7cb553 100644 --- a/megatron/model/module.py +++ b/megatron/model/module.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Megatron Module""" @@ -10,10 +11,9 @@ from megatron.core import mpu, tensor_parallel -_FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype] -_HALF_TYPES = [get_accelerator().HalfTensor(0).dtype] -_BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype] - +_FLOAT_TYPES = None +_HALF_TYPES = None +_BF16_TYPES = None def param_is_not_shared(param): @@ -131,6 +131,9 @@ def conversion_helper(val, conversion): def fp32_to_float16(val, float16_convertor): """Convert fp32 `val` to fp16/bf16""" + global _FLOAT_TYPES + if _FLOAT_TYPES is None: + _FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype] def half_conversion(val): val_typecheck = val if isinstance(val_typecheck, (Parameter, Variable)): @@ -143,6 +146,11 @@ def half_conversion(val): def float16_to_fp32(val): """Convert fp16/bf16 `val` to fp32""" + global _HALF_TYPES, _BF16_TYPES + if _HALF_TYPES is None: + _HALF_TYPES = [get_accelerator().HalfTensor(0).dtype] + if _BF16_TYPES is None: + _BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype] def float_conversion(val): val_typecheck = val if isinstance(val_typecheck, (Parameter, Variable)): diff --git a/megatron/model/rmsnorm.py b/megatron/model/rmsnorm.py index 60e8978171..68b792c83d 100644 --- a/megatron/model/rmsnorm.py +++ b/megatron/model/rmsnorm.py @@ -1,16 +1,33 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +from deepspeed.accelerator import get_accelerator +from megatron import get_args + import torch +from torch.nn import init from torch.nn.parameter import Parameter + # Taken from facebookresearch/llama class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__( + self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False + ): super().__init__() self.eps = eps - self.weight = Parameter(torch.ones(dim)) + init_device = None + if get_accelerator().device_name() == "hpu": + init_device = get_accelerator().current_device_name() + self.weight = Parameter( + torch.empty(dim, device=init_device, dtype=get_args().params_dtype) + ) + init.ones_(self.weight) + self.sequence_parallel = sequence_parallel + setattr(self.weight, "sequence_parallel", self.sequence_parallel) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) - return output * self.weight \ No newline at end of file + return output * self.weight diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py index 4d4497e0cd..0a7acb9efc 100644 --- a/megatron/model/rotary_pos_embedding.py +++ b/megatron/model/rotary_pos_embedding.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
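The rewritten rmsnorm.py keeps the same reference math but adds device-/dtype-aware initialization and a sequence_parallel tag on the weight. The core computation, as a standalone function with illustrative shapes:

import torch


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm as in megatron/model/rmsnorm.py: normalize by the
    root-mean-square over the last dim (computed in fp32), then rescale."""
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
print(rms_norm(x, w).shape)  # torch.Size([2, 4, 8])
# Unlike LayerNorm there is no mean subtraction and no bias term, only a scale.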
# coding=utf-8 # The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \ @@ -11,6 +12,10 @@ __all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] +# sin, cos tensors cached for all devices +cos_cached = None +sin_cached = None + class RotaryEmbedding(nn.Module): def __init__(self, dim, theta=10000): super().__init__() @@ -47,10 +52,19 @@ def apply_rotary_pos_emb(t, freqs): check https://kexue.fm/archives/8265 for detailed formulas """ rot_dim = freqs.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + t_pass = None + if t.shape[-1] != rot_dim: + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + global cos_cached, sin_cached + if cos_cached is None or sin_cached is None or t.shape[0] != cos_cached.shape[0]: + freqs_ = freqs[:t.shape[0]] + cos_cached = freqs_.cos().to(t.dtype) + sin_cached = freqs_.sin().to(t.dtype) # first part is cosine component # second part is sine component, need to change signs with _rotate_half method - t = (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype)) - return t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) + t = (t * cos_cached) + (_rotate_half(t) * sin_cached) + if t_pass is None: + return t + return torch.cat((t, t_pass), dim=-1) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e75f13a24f..592ff2855b 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Transformer.""" @@ -24,6 +25,7 @@ try: from deepspeed.sequence.layer import DistributedAttention + dist_attn_supported = True except ImportError: dist_attn_supported = False @@ -36,9 +38,12 @@ try: # FlashAttention (1.x) from flash_attn.flash_attn_interface import flash_attn_unpadded_func - from flash_attn.flash_attn_triton import flash_attn_func except ImportError: flash_attn_unpadded_func = None + +try: + from flash_attn.flash_attn_triton import flash_attn_func +except ImportError: flash_attn_func = None try: @@ -66,28 +71,31 @@ hyperparameters: transformer hyperparameters """ + class DropPath(MegatronModule): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ - def __init__(self, drop_prob=0.): + def __init__(self, drop_prob=0.0): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, hidden_state): - if self.drop_prob == 0. or not self.training: + if self.drop_prob == 0.0 or not self.training: return hidden_state keep_prob = 1 - self.drop_prob # work with diff dim tensors, not just 2D ConvNets # hidden_state: [s, b, h] shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) - random_tensor = keep_prob + \ - torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor = keep_prob + torch.rand( + shape, dtype=hidden_state.dtype, device=hidden_state.device + ) random_tensor.floor_() # binarize output = hidden_state.div(keep_prob) * random_tensor return output + class ParallelMLP(MegatronModule): """MLP. 
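The rotary change above caches the cos/sin tables at module scope and recomputes them only when the sequence length changes, instead of calling freqs.cos()/freqs.sin() on every forward. A small standalone version of apply_rotary_pos_emb with that cache (shapes and theta are illustrative):

import torch

_COS, _SIN = None, None          # module-level cache, as in the patch


def _rotate_half(x):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(t, freqs):
    """Apply rotary embeddings, recomputing cos/sin only when the cached
    sequence length no longer matches t (mirrors the patch's caching)."""
    global _COS, _SIN
    if _COS is None or _COS.shape[0] != t.shape[0]:
        f = freqs[: t.shape[0]]
        _COS, _SIN = f.cos().to(t.dtype), f.sin().to(t.dtype)
    return (t * _COS) + (_rotate_half(t) * _SIN)


seq, b, nh, hd = 16, 2, 4, 8
t = torch.randn(seq, b, nh, hd)
inv_freq = 1.0 / (10000 ** (torch.arange(0, hd, 2).float() / hd))
freqs = torch.einsum("i,j->ij", torch.arange(seq).float(), inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]   # [seq, 1, 1, hd]
print(apply_rope(t, freqs).shape)   # torch.Size([16, 2, 4, 8])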
@@ -116,7 +124,7 @@ def __init__(self, config, moe=False, enable_expert_tensor_parallelism=False): gather_output=False, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, ) self.bias_gelu_fusion = False @@ -128,13 +136,17 @@ def __init__(self, config, moe=False, enable_expert_tensor_parallelism=False): elif args.onnx_safe: self.activation_func = erf_gelu elif args.swiglu: + def swiglu(x): x = torch.chunk(x, 2, dim=-1) return F.silu(x[0]) * x[1] + self.activation_func = swiglu elif args.squared_relu: + def squared_relu(x): return torch.pow(F.relu(x), 2) + self.activation_func = squared_relu else: self.bias_gelu_fusion = args.bias_gelu_fusion @@ -149,7 +161,7 @@ def squared_relu(x): bias=self.add_bias, input_is_parallel=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, ) def forward(self, hidden_states): @@ -171,10 +183,12 @@ def forward(self, hidden_states): output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias + class SwitchMLP(MegatronModule): """ Routes input to one of N MLP "experts" """ + def __init__(self, config): super(SwitchMLP, self).__init__() args = get_args() @@ -191,29 +205,29 @@ def forward(self, hidden_states): route = self.router(hidden_states) route = torch.nn.functional.softmax(route, dim=2) max_prob, max_ind = torch.max(route, dim=2) - max_prob = torch.unsqueeze(max_prob, 2) # [s b 1] + max_prob = torch.unsqueeze(max_prob, 2) # [s b 1] # TODO (rprenger) TODO this could be made easier to read # Converting [s, b, h] to [s*b, h]. # Each vector could be routed differently - hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h] - max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1] - max_ind = max_ind.view(-1) # [s*b] + hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h] + max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1] + max_ind = max_ind.view(-1) # [s*b] output_total = torch.empty_like(hidden_states) output_bias_total = torch.empty_like(hidden_states) - #TODO (rprenger) This does each expert in serial, but it could be parallelized + # TODO (rprenger) This does each expert in serial, but it could be parallelized for expert_num, expert in enumerate(self.experts): local_indices = (max_ind == expert_num).nonzero() - hidden = hidden_states[local_indices,:] + hidden = hidden_states[local_indices, :] output, output_bias = expert(hidden) output_bias = output_bias.expand_as(output) - output_total[local_indices,:] = output - output_bias_total[local_indices,:] = output_bias + output_total[local_indices, :] = output + output_bias_total[local_indices, :] = output_bias - output_total = output_total*max_prob - output_bias_total = output_bias_total*max_prob + output_total = output_total * max_prob + output_bias_total = output_bias_total * max_prob output_total = output_total.view(s, b, h) output_bias_total = output_bias_total.view(s, b, h) @@ -222,8 +236,7 @@ def forward(self, hidden_states): class CoreAttention(MegatronModule): - def __init__(self, layer_number, config, - attn_mask_type=AttnMaskType.padding): + def __init__(self, layer_number, config, attn_mask_type=AttnMaskType.padding): super(CoreAttention, self).__init__() self.fp16 = config.fp16 self.bf16 = config.bf16 @@ -242,14 +255,19 @@ def __init__(self, layer_number, config, seq_parallel_world_size = 1 if 
parallel_state.sequence_parallel_is_initialized(): seq_parallel_world_size = parallel_state.get_sequence_parallel_world_size() - world_size = seq_parallel_world_size if seq_parallel_world_size > 1 else parallel_state.get_tensor_model_parallel_world_size() + world_size = ( + seq_parallel_world_size + if seq_parallel_world_size > 1 + else parallel_state.get_tensor_model_parallel_world_size() + ) - self.hidden_size_per_partition = core.utils.divide(projection_size, - world_size) + self.hidden_size_per_partition = core.utils.divide(projection_size, world_size) self.hidden_size_per_attention_head = core.utils.divide( - projection_size, config.num_attention_heads) + projection_size, config.num_attention_heads + ) self.num_attention_heads_per_partition = core.utils.divide( - config.num_attention_heads, world_size) + config.num_attention_heads, world_size + ) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -258,49 +276,56 @@ def __init__(self, layer_number, config, self.norm_factor *= coeff self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.fp16, self.bf16, + self.fp16, + self.bf16, self.attn_mask_type, config.masked_softmax_fusion, attention_mask_func, self.attention_softmax_in_fp32, - coeff) + coeff, + ) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - def forward(self, query_layer, key_layer, - value_layer, attention_mask): + def forward(self, query_layer, key_layer, value_layer, attention_mask): # =================================== # Raw attention scores. [b, np, s, s] # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], - output_size[0] * output_size[1], -1) + query_layer = query_layer.view( + output_size[2], output_size[0] * output_size[1], -1 + ) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], - output_size[0] * output_size[1], -1) + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( - (output_size[0]*output_size[1], output_size[2], output_size[3]), - query_layer.dtype, "mpu") + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, + "mpu", + ) # Raw attention scores. 
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] + query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, alpha=(1.0/self.norm_factor)) + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -310,8 +335,7 @@ def forward(self, query_layer, key_layer, # =========================== # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, - attention_mask) + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -329,18 +353,22 @@ def forward(self, query_layer, key_layer, # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), - output_size[0] * output_size[1], -1) + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -352,8 +380,9 @@ def forward(self, query_layer, key_layer, context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) context_layer = context_layer.view(*new_context_layer_shape) return context_layer @@ -369,12 +398,24 @@ class FlashSelfAttention(torch.nn.Module): attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, - device=None, dtype=None): + + def __init__( + self, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + device=None, + dtype=None, + ): super().__init__() - assert flash_attn_unpadded_func is not None or flash_attn_varlen_func is not None or flash_attn_builder is not None, \ - ('Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention') - assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + assert ( + flash_attn_unpadded_func is not None + or flash_attn_varlen_func is not None + or flash_attn_builder is not None + ), "Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention" + assert ( + rearrange is not None + ), "Please install einops first, e.g., with pip install einops" self.causal = causal self.softmax_scale = softmax_scale self.dropout_p = attention_dropout @@ -385,14 +426,18 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, self.use_flash_attn_builder_v2 = False self.use_flash_attn = False if args.use_flash_attn_builder: - if 
hasattr(flash_attn_builder, 'flash_attn_func'): + if hasattr(flash_attn_builder, "flash_attn_func"): self.flash_attn_func = flash_attn_builder.flash_attn_func self.use_flash_attn_builder_v1 = True else: self.flash_attn_func = flash_attn_builder.flash_attn_func_v2 self.use_flash_attn_builder_v2 = True else: - self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func + self.flash_attn_func = ( + flash_attn_varlen_func + if args.use_flash_attn_v2 + else flash_attn_unpadded_func + ) self.use_flash_attn = True def forward(self, q, k, v): @@ -402,42 +447,67 @@ def forward(self, q, k, v): q, k, v: The tensor containing the query, key, and value. (B, S, H, D) """ - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) assert all((get_accelerator().on_accelerator(i) for i in (q, k, v))) batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] if self.use_flash_attn: - q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, - device=q.device) + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q.device, + ) elif self.use_flash_attn_builder_v1: - q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]] + q, k, v = [ + rearrange(x, "b s h d -> b h s d").contiguous() for x in [q, k, v] + ] else: # use_flash_attn_builder_v2 - q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]] + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] if self.training: # during training q,k,v always have same seqlen assert seqlen_k == seqlen_q is_causal = self.causal - cu_seqlens_k = cu_seqlens_q if get_accelerator().device_name() == 'cuda' else None + cu_seqlens_k = ( + cu_seqlens_q if get_accelerator().device_name() == "cuda" else None + ) dropout_p = self.dropout_p else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen is_causal = seqlen_q == seqlen_k - cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, - device=q.device) if get_accelerator().device_name() == 'cuda' else None + cu_seqlens_k = ( + torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=q.device, + ) + if get_accelerator().device_name() == "cuda" + else None + ) dropout_p = 0 if self.use_flash_attn: output = self.flash_attn_func( - q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlen_q, + seqlen_k, dropout_p, - softmax_scale=self.softmax_scale, causal=is_causal + softmax_scale=self.softmax_scale, + causal=is_causal, ) else: # use_flash_attn_builder @@ -446,15 +516,16 @@ def forward(self, q, k, v): ) if self.use_flash_attn: - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + output = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) elif self.use_flash_attn_builder_v1: - output = rearrange(output, 'b h s d -> b s h d').contiguous() + output = rearrange(output, "b h s d -> b s h d").contiguous() else: # use_flash_attn_builder_v2: - output = rearrange(output, 'b h s d -> b s h d') + output = rearrange(output, "b h s d -> b s h d") return output + class FlashSelfAttentionTriton(torch.nn.Module): """Implement the scaled dot product attention with softmax. Arguments @@ -465,11 +536,22 @@ class FlashSelfAttentionTriton(torch.nn.Module): attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, - device=None, dtype=None): + + def __init__( + self, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + device=None, + dtype=None, + ): super().__init__() - assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.') - assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + assert ( + flash_attn_func is not None + ), "Triton version of FlashAttention is not installed." + assert ( + rearrange is not None + ), "Please install einops first, e.g., with pip install einops" self.causal = causal self.softmax_scale = softmax_scale self.dropout_p = attention_dropout @@ -483,13 +565,13 @@ def forward(self, q, k, v): assert q.dtype in [torch.float16, torch.bfloat16] assert q.is_cuda - q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() - for x in (q, k, v)] - + q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)] + output = flash_attn_func(q, k, v, None, self.causal) - output = rearrange(output, 'b s h d -> s b (h d)').contiguous() + output = rearrange(output, "b s h d -> s b (h d)").contiguous() return output + class ParallelAttention(MegatronModule): """Parallel self-attention layer abstract class. @@ -497,9 +579,13 @@ class ParallelAttention(MegatronModule): and returns output of the same size. 
""" - def __init__(self, config, layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding): + def __init__( + self, + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding, + ): super(ParallelAttention, self).__init__() args = get_args() self.layer_number = max(1, layer_number) @@ -509,12 +595,18 @@ def __init__(self, config, layer_number, self.sequence_parallel = config.sequence_parallel self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.use_gqa = (self.num_attention_heads != self.num_key_value_heads) - - self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \ - args.use_flash_attn_builder) \ - and attention_type == AttnType.self_attn \ + self.use_gqa = self.num_attention_heads != self.num_key_value_heads + + self.use_flash_attn = ( + ( + args.use_flash_attn_v1 + or args.use_flash_attn_triton + or args.use_flash_attn_v2 + or args.use_flash_attn_builder + ) + and attention_type == AttnType.self_attn and self.attn_mask_type == AttnMaskType.causal + ) self.use_flash_attn_triton = args.use_flash_attn_triton if self.use_flash_attn: global flash_attn_builder @@ -524,38 +616,53 @@ def __init__(self, config, layer_number, flash_attn_builder = None if args.use_flash_attn_v1: - assert flash_attn_unpadded_func != None, "Cannot import FlashAttention v1 " + assert ( + flash_attn_unpadded_func != None + ), "Cannot import FlashAttention v1 " if args.use_flash_attn_v2: - assert flash_attn_varlen_func != None, "Cannot import FlashAttention v2 " + assert ( + flash_attn_varlen_func != None + ), "Cannot import FlashAttention v2 " if args.use_flash_attn_triton: assert flash_attn_func != None, "Cannot import FlashAttention triton " if args.use_flash_attn_builder: - assert flash_attn_builder != None, "Cannot find FlashAttention op builder " + assert ( + flash_attn_builder != None + ), "Cannot find FlashAttention op builder " - assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' - 'self-attention for now') - assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' - 'supports causal mask for now') + assert attention_type == AttnType.self_attn, ( + "FlashAttention code path only supports " "self-attention for now" + ) + assert self.attn_mask_type == AttnMaskType.causal, ( + "FlashAttention code path only " "supports causal mask for now" + ) if rearrange is None: - raise ImportError('einops is not installed, please install with pip install einops') + raise ImportError( + "einops is not installed, please install with pip install einops" + ) projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. 
world_size = parallel_state.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = core.utils.divide( - projection_size, config.num_attention_heads) + projection_size, config.num_attention_heads + ) self.num_attention_heads_per_partition = core.utils.divide( - config.num_attention_heads, world_size) + config.num_attention_heads, world_size + ) # Per GQA head and per partition values self.num_key_value_heads_per_partition = core.utils.divide( - config.num_key_value_heads, world_size) + config.num_key_value_heads, world_size + ) self.num_key_value_groups = core.utils.divide( - config.num_attention_heads, config.num_key_value_heads) + config.num_attention_heads, config.num_key_value_heads + ) kv_projection_size = config.kv_channels * config.num_key_value_heads assert self.hidden_size_per_attention_head == core.utils.divide( - kv_projection_size, config.num_key_value_heads) + kv_projection_size, config.num_key_value_heads + ) # Strided linear layer. if attention_type == AttnType.self_attn: @@ -565,7 +672,8 @@ def __init__(self, config, layer_number, config=config, init_method=config.init_method, bias=args.add_bias_linear, - gather_output=False) + gather_output=False, + ) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( @@ -574,8 +682,8 @@ def __init__(self, config, layer_number, config=config, init_method=config.init_method, bias=config.add_bias_linear, - gather_output=False) - + gather_output=False, + ) self.key_value = tensor_parallel.ColumnParallelLinear( config.hidden_size, @@ -583,28 +691,48 @@ def __init__(self, config, layer_number, config=config, init_method=config.init_method, bias=config.add_bias_linear, - gather_output=False) + gather_output=False, + ) # Currently FlashAttention only works with causal mask if self.use_flash_attn_triton: - local_attn = FlashSelfAttentionTriton(causal=True, attention_dropout=args.attention_dropout) + local_attn = FlashSelfAttentionTriton( + causal=True, attention_dropout=args.attention_dropout + ) elif self.use_flash_attn: - local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout) + local_attn = FlashSelfAttention( + causal=True, attention_dropout=config.attention_dropout + ) else: local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type) - self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \ - or args.force_ds_sequence_parallel + self.enable_ds_sequence_parallel = ( + parallel_state.get_sequence_parallel_world_size() > 1 + or args.force_ds_sequence_parallel + ) if self.enable_ds_sequence_parallel: - assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' - assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group()) + assert ( + dist_attn_supported + ), "Distributed attention is not supported in this DeepSpeed version" + assert ( + args.num_attention_heads + % parallel_state.get_sequence_parallel_world_size() + == 0 + ) + self.dist_attn = DistributedAttention( + local_attn, + parallel_state.get_sequence_parallel_group(), + gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0, + ) + # flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension. 
else: if self.use_flash_attn: self.core_attention_flash = local_attn else: self.core_attention = local_attn - self.checkpoint_core_attention = config.recompute_granularity == 'selective' + self.checkpoint_core_attention = ( + config.recompute_granularity == "selective" + ) # Output. self.dense = tensor_parallel.RowParallelLinear( @@ -614,29 +742,38 @@ def __init__(self, config, layer_number, init_method=config.output_layer_init_method, bias=args.add_bias_linear, input_is_parallel=True, - skip_bias_add=True) - + skip_bias_add=True, + ) - def _checkpointed_attention_forward(self, query_layer, key_layer, - value_layer, attention_mask, - rotary_pos_emb=None): + def _checkpointed_attention_forward( + self, query_layer, key_layer, value_layer, attention_mask, rotary_pos_emb=None + ): """Forward method with activation checkpointing.""" + def custom_forward(*inputs): query_layer = inputs[0] key_layer = inputs[1] value_layer = inputs[2] attention_mask = inputs[3] - output_ = self.core_attention(query_layer, key_layer, - value_layer, attention_mask) + output_ = self.core_attention( + query_layer, key_layer, value_layer, attention_mask + ) return output_ - q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ - else rotary_pos_emb + q_pos_emb, k_pos_emb = ( + (None, None) if rotary_pos_emb is None else rotary_pos_emb + ) hidden_states = tensor_parallel.checkpoint( custom_forward, - False, query_layer, key_layer, value_layer, attention_mask, - q_pos_emb, k_pos_emb) + False, + query_layer, + key_layer, + value_layer, + attention_mask, + q_pos_emb, + k_pos_emb, + ) return hidden_states @@ -647,28 +784,49 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size): self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, - device=get_accelerator().current_device_name()) + device=get_accelerator().current_device_name(), + ) def repeat_kv(self, hidden_states, n_rep): slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, :, None, :].expand( - slen, batch, num_key_value_heads_per_partition, n_rep, head_dim) - return hidden_states.reshape(slen, batch, - num_key_value_heads_per_partition * n_rep, - head_dim) - + elif num_key_value_heads_per_partition == 1: + # If no of KV heads is 1 then just perform expand operation + # instead of unsqueeze, expand and reshape to match query states. 
+ return hidden_states.expand(slen, batch, n_rep, head_dim) + else: + hidden_states = hidden_states[:, :, :, None, :].expand( + slen, batch, num_key_value_heads_per_partition, n_rep, head_dim + ) + return hidden_states.reshape( + slen, batch, num_key_value_heads_per_partition * n_rep, head_dim + ) + def split_tensor(self, mixed_x_layer): - query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head)) - key_layer = mixed_x_layer[:, :, :, -2, :] - value_layer = mixed_x_layer[:, :, :, -1, :] + query_layer, key_layer, value_layer = torch.split( + mixed_x_layer, [self.num_key_value_groups, 1, 1], dim=-2 + ) + query_layer = query_layer.reshape( + mixed_x_layer.shape[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = torch.squeeze(key_layer, -2) + value_layer = torch.squeeze(value_layer, -2) return query_layer, key_layer, value_layer - def forward(self, hidden_states, attention_mask, - encoder_output=None, inference_params=None, - rotary_pos_emb=None): + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + inference_params=None, + rotary_pos_emb=None, + ): # hidden_states: [sq, b, h] # ================================================= @@ -680,15 +838,20 @@ def forward(self, hidden_states, attention_mask, inf_max_seq_len = inference_params.max_sequence_len inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size) + inf_max_seq_len, inf_max_batch_size + ) inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size) + inf_max_seq_len, inf_max_batch_size + ) inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, inference_value_memory) + inference_key_memory, + inference_value_memory, + ) is_first_step = True else: - inference_key_memory, inference_value_memory = \ + inference_key_memory, inference_value_memory = ( inference_params.key_value_memory_dict[self.layer_number] + ) # ===================== # Query, Key, and Value @@ -699,43 +862,45 @@ def forward(self, hidden_states, attention_mask, mixed_x_layer, _ = self.query_key_value(hidden_states) # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (-1, (self.num_key_value_groups + 2), - self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + -1, + (self.num_key_value_groups + 2), + self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn] - (query_layer, - key_layer, - value_layer) = self.split_tensor(mixed_x_layer) + (query_layer, key_layer, value_layer) = self.split_tensor(mixed_x_layer) # Repeat kv if self.use_gqa: key_layer = self.repeat_kv(key_layer, self.num_key_value_groups) - value_layer = self.repeat_kv(value_layer, - self.num_key_value_groups) + value_layer = self.repeat_kv(value_layer, self.num_key_value_groups) else: - assert not self.use_gqa, 'GQA + cross-attn not tested yet' + assert not self.use_gqa, "GQA + cross-attn not tested yet" # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head) + new_tensor_shape = 
mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, - value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( + mixed_kv_layer, 2 + ) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_tensor_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_tensor_shape) # ================================== @@ -747,7 +912,7 @@ def forward(self, hidden_states, attention_mask, if isinstance(rotary_pos_emb, tuple): rotary_pos_emb = rotary_pos_emb else: - rotary_pos_emb = ((rotary_pos_emb,) * 2) + rotary_pos_emb = (rotary_pos_emb,) * 2 if inference_params: batch_start = inference_params.batch_size_offset @@ -757,15 +922,16 @@ def forward(self, hidden_states, attention_mask, sequence_end = sequence_start + key_layer.size(0) assert sequence_end <= inference_key_memory.size(0) # Copy key and values. - inference_key_memory[sequence_start:sequence_end, - batch_start:batch_end, ...] = key_layer - inference_value_memory[sequence_start:sequence_end, - batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[ - :sequence_end, batch_start:batch_end, ...] + inference_key_memory[ + sequence_start:sequence_end, batch_start:batch_end, ... + ] = key_layer + inference_value_memory[ + sequence_start:sequence_end, batch_start:batch_end, ... + ] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[ - :sequence_end, batch_start:batch_end, ...] - + :sequence_end, batch_start:batch_end, ... + ] # adjust the key rotary positional embedding if rotary_pos_emb is not None: @@ -787,7 +953,6 @@ def forward(self, hidden_states, attention_mask, k_pos_emb = k_pos_emb[:sequence_end, :, :, :] rotary_pos_emb = (q_pos_emb, k_pos_emb) - # ================================== # core attention computation # ================================== @@ -803,38 +968,58 @@ def forward(self, hidden_states, attention_mask, # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) if self.enable_ds_sequence_parallel: + batch_dim_idx = 1 if self.use_flash_attn: if not self.use_flash_attn_triton: - query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous() - for x in (query_layer, key_layer, value_layer)] - - context_layer = self.dist_attn(query_layer, key_layer, value_layer) + query_layer, key_layer, value_layer = [ + rearrange(x, "s b ... 
-> b s ...").contiguous() + for x in (query_layer, key_layer, value_layer) + ] + batch_dim_idx = 0 + + context_layer = self.dist_attn( + query_layer, key_layer, value_layer, batch_dim_idx + ) if not self.use_flash_attn_triton: - context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() else: - context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask) + context_layer = self.dist_attn( + query_layer, key_layer, value_layer, attention_mask + ) else: if self.use_flash_attn: if not self.use_flash_attn_triton: - query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous() - for x in (query_layer, key_layer, value_layer)] + query_layer, key_layer, value_layer = [ + rearrange(x, "s b ... -> b s ...").contiguous() + for x in (query_layer, key_layer, value_layer) + ] if self.sequence_parallel: - context_layer = self.core_attention_flash(query_layer, key_layer, value_layer) + context_layer = self.core_attention_flash( + query_layer, key_layer, value_layer + ) else: with tensor_parallel.get_cuda_rng_tracker().fork(): - context_layer = self.core_attention_flash(query_layer, key_layer, value_layer) + context_layer = self.core_attention_flash( + query_layer, key_layer, value_layer + ) if not self.use_flash_attn_triton: - context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() else: if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask) + query_layer, key_layer, value_layer, attention_mask + ) else: context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask) + query_layer, key_layer, value_layer, attention_mask + ) # ================= # Output. [sq, b, h] @@ -857,22 +1042,21 @@ def bias_dropout_add(x, bias, residual, prob, training): def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add @torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: Optional[torch.Tensor], - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_train( + x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: Optional[torch.Tensor], - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_inference( + x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) @@ -883,10 +1067,15 @@ class ParallelTransformerLayer(MegatronModule): output of the same size. 
""" - def __init__(self, config, - layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - drop_path_rate=0., num_experts=1): + def __init__( + self, + config, + layer_number, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.0, + num_experts=1, + ): # retriever=None): args = get_args() @@ -894,106 +1083,130 @@ def __init__(self, config, self.layer_number = layer_number self.layer_type = layer_type - self.apply_residual_connection_post_layernorm \ - = config.apply_residual_connection_post_layernorm + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) self.bf16 = config.bf16 self.fp32_residual_connection = config.fp32_residual_connection # Layernorm on the input data. - if args.normalization == 'layernorm': - if get_accelerator().device_name() == 'cuda': + if args.normalization == "layernorm": + if get_accelerator().device_name() == "cuda": self.input_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=config.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p, - mem_efficient_ln=args.mem_efficient_ln) + mem_efficient_ln=args.mem_efficient_ln, + ) else: self.input_layernorm = LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) else: - self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.input_layernorm = RMSNorm( + config.hidden_size, + config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel, + ) + # self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon_ # Self attention. self.self_attention = ParallelAttention( config, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) + attn_mask_type=self_attn_mask_type, + ) self.hidden_dropout = config.hidden_dropout self.bias_dropout_fusion = config.bias_dropout_fusion self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None # Layernorm on the attention output - if args.normalization == 'layernorm': - if get_accelerator().device_name() == 'cuda': + if args.normalization == "layernorm": + if get_accelerator().device_name() == "cuda": self.post_attention_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=not config.persist_layer_norm, sequence_parallel=config.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p, - mem_efficient_ln=args.mem_efficient_ln) + mem_efficient_ln=args.mem_efficient_ln, + ) else: self.post_attention_layernorm = LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) else: - self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel, + ) # Cross attention. 
- if self.layer_type in (LayerType.decoder, - LayerType.retro_decoder, - LayerType.retro_decoder_with_retriever, - LayerType.retro_encoder): + if self.layer_type in ( + LayerType.decoder, + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + LayerType.retro_encoder, + ): self.inter_attention = ParallelAttention( - config, - layer_number, - attention_type=AttnType.cross_attn) + config, layer_number, attention_type=AttnType.cross_attn + ) # Layernorm on the attention output. - if args.normalization == 'layernorm': + if args.normalization == "layernorm": self.post_inter_attention_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=not config.persist_layer_norm, sequence_parallel=config.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p, - mem_efficient_ln=args.mem_efficient_ln) + mem_efficient_ln=args.mem_efficient_ln, + ) else: - self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.post_inter_attention_layernorm = RMSNorm( + config.hidden_size, + config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel, + ) # MLP self.num_experts = num_experts if args.num_experts_switch is not None: - self.mlp = SwitchMLP(config) # Megatron-LM's MoE + self.mlp = SwitchMLP(config) # Megatron-LM's MoE else: - if self.num_experts <= 1: # dense, not MoE + if self.num_experts <= 1: # dense, not MoE self.mlp = ParallelMLP(config) - else: # DeepSpeed's MoE + else: # DeepSpeed's MoE enable_expert_tensor_parallelism = args.enable_expert_tensor_parallelism - self.mlp = MoE(args.hidden_size, - ParallelMLP(config, - moe=True, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.topk, - use_residual=(args.mlp_type == 'residual'), - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, - top2_2nd_expert_sampling=args.moe_top2_2nd_expert_sampling) + self.mlp = MoE( + args.hidden_size, + ParallelMLP( + config, + moe=True, + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.topk, + use_residual=(args.mlp_type == "residual"), + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + top2_2nd_expert_sampling=args.moe_top2_2nd_expert_sampling, + ) # Set bias+dropout+add fusion grad_enable execution handler. 
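The MLP branch above picks between a dense ParallelMLP, Megatron's SwitchMLP, and DeepSpeed's MoE layer. For orientation, the top-1 routing that SwitchMLP (reformatted earlier in this diff) performs, reduced to plain tensors and tiny illustrative sizes:

import torch
import torch.nn.functional as F

s, b, h, num_experts = 4, 2, 8, 3
hidden = torch.randn(s, b, h)
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList(torch.nn.Linear(h, h) for _ in range(num_experts))

route = F.softmax(router(hidden), dim=2)       # [s, b, E]
max_prob, max_ind = torch.max(route, dim=2)    # top-1 expert per token
tokens = hidden.view(-1, h)                    # [s*b, h]
max_prob = max_prob.view(-1, 1)
max_ind = max_ind.view(-1)

output = torch.empty_like(tokens)
for e, expert in enumerate(experts):           # serial over experts, as in SwitchMLP
    idx = (max_ind == e).nonzero(as_tuple=True)[0]
    output[idx] = expert(tokens[idx])
output = (output * max_prob).view(s, b, h)     # rescale by the routing probability
print(output.shape)   # torch.Size([4, 2, 8])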
- TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = \ - nullcontext if use_nvfuser else torch.enable_grad + self.bias_dropout_add_exec_handler = ( + nullcontext if use_nvfuser else torch.enable_grad + ) if args.retro_add_retriever: retro_args = get_retro_args() @@ -1011,23 +1224,24 @@ def __init__(self, config, pre_process=True, post_process=False, ) - self._retriever_key = 'retriever' + self._retriever_key = "retriever" else: self.retriever = None - def default_decoder_cross_attention(self, - encoder_output, - enc_dec_attn_mask, - layernorm_input, - layernorm_output, - bias_dropout_add_func): - '''Cross attention for a standard encoder-decoder model.''' + def default_decoder_cross_attention( + self, + encoder_output, + enc_dec_attn_mask, + layernorm_input, + layernorm_output, + bias_dropout_add_func, + ): + """Cross attention for a standard encoder-decoder model.""" # Attention. - attention_output, attention_bias = \ - self.inter_attention(layernorm_output, - enc_dec_attn_mask, - encoder_output=encoder_output) + attention_output, attention_bias = self.inter_attention( + layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output + ) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -1041,21 +1255,17 @@ def default_decoder_cross_attention(self, # Bias-dropout-add. with self.bias_dropout_add_exec_handler(): layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias, - residual, - self.hidden_dropout) + attention_output, attention_bias, residual, self.hidden_dropout + ) # Layer norm. layernorm_output = self.post_inter_attention_layernorm(layernorm_input) return layernorm_input, layernorm_output - def retro_encoder_cross_attention(self, - retriever_output, - layernorm_input, - layernorm_output, - bias_dropout_add_func): + def retro_encoder_cross_attention( + self, retriever_output, layernorm_input, layernorm_output, bias_dropout_add_func + ): """Cross attention for Retro encoder. Notation: @@ -1067,16 +1277,15 @@ def retro_encoder_cross_attention(self, r : Number of retrieved tokens (neighbors + continuation). """ - ns, bs, d = layernorm_output.shape # [r, bs * l * k, d] + ns, bs, d = layernorm_output.shape # [r, bs * l * k, d] # Divide sequence dimension into chunks. - chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length, - -1, - self.retro_num_neighbors, - d) - chunked_outputs_before_layer_norm = \ - layernorm_input.reshape(self.retro_retrieved_length, -1, - self.retro_num_neighbors, d) # [r, bs*l, k, d] + chunked_outputs = layernorm_output.reshape( + self.retro_retrieved_length, -1, self.retro_num_neighbors, d + ) + chunked_outputs_before_layer_norm = layernorm_input.reshape( + self.retro_retrieved_length, -1, self.retro_num_neighbors, d + ) # [r, bs*l, k, d] # Per-chunk attention. layernorm_inputs = [] @@ -1084,51 +1293,55 @@ def retro_encoder_cross_attention(self, for k in range(self.retro_num_neighbors): # Attention. 
- chunked_output = chunked_outputs[:,:,k].contiguous() - attention_output, attention_bias = \ - self.inter_attention( - chunked_output, # Q (neighbor embedding) - None, - encoder_output=retriever_output) # K, V (hidden act) + chunked_output = chunked_outputs[:, :, k].contiguous() + attention_output, attention_bias = self.inter_attention( + chunked_output, # Q (neighbor embedding) + None, + encoder_output=retriever_output, + ) # K, V (hidden act) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = chunked_output else: - residual = chunked_outputs_before_layer_norm[:,:,k] + residual = chunked_outputs_before_layer_norm[:, :, k] # Re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, - None if attention_bias is None else attention_bias.expand_as(residual), + ( + None + if attention_bias is None + else attention_bias.expand_as(residual) + ), residual, - self.hidden_dropout) + self.hidden_dropout, + ) layernorm_inputs.append(layernorm_input) # Layer norm. - layernorm_output = \ - self.post_inter_attention_layernorm(layernorm_input) + layernorm_output = self.post_inter_attention_layernorm(layernorm_input) layernorm_outputs.append(layernorm_output) # Concatenate layer norms. # layernorm_input : [r, k * bs * l, d] # layernorm_output : [r, k * bs * l, d] - layernorm_input = \ - torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d) - layernorm_output = \ - torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d) + layernorm_input = torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d) + layernorm_output = torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d) return layernorm_input, layernorm_output - def retro_decoder_cross_attention(self, - retriever_input, - retriever_output, - retriever_attn_mask, - layernorm_input, - layernorm_output, - inference_params, - bias_dropout_add_func): + def retro_decoder_cross_attention( + self, + retriever_input, + retriever_output, + retriever_attn_mask, + layernorm_input, + layernorm_output, + inference_params, + bias_dropout_add_func, + ): """Cross attention for Retro decoder. 
Notation: @@ -1149,22 +1362,27 @@ def retro_decoder_cross_attention(self, first_ns = ns % self.retro_chunk_length if first_ns > 0: raise Exception("test this case.") - first_chunk, rest_chunk = \ - layernorm_output[:first_ns], layernorm_output[first_ns:] + first_chunk, rest_chunk = ( + layernorm_output[:first_ns], + layernorm_output[first_ns:], + ) first_chunk = torch.nn.functional.pad( first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), - 'constant', - 0) - chunked_output = \ - torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] + "constant", + 0, + ) + chunked_output = torch.cat( + (first_chunk, rest_chunk), dim=0 + ) # [l * m, bs, d] else: - chunked_output = layernorm_output # [l * m, bs, d] - chunked_output = chunked_output \ - .reshape(l, self.retro_chunk_length, bs, d) \ - .permute(1, 2, 0, 3) \ - .reshape(self.retro_chunk_length, bs * l, d) \ + chunked_output = layernorm_output # [l * m, bs, d] + chunked_output = ( + chunked_output.reshape(l, self.retro_chunk_length, bs, d) + .permute(1, 2, 0, 3) + .reshape(self.retro_chunk_length, bs * l, d) .contiguous() + ) # Get Encoder Output retriever_output = self.retriever( @@ -1172,9 +1390,11 @@ def retro_decoder_cross_attention(self, attention_mask=retriever_attn_mask, retriever_output=chunked_output, retriever_attn_mask=retriever_attn_mask, - inference_params=inference_params) # [r, k * bs * l , d] + inference_params=inference_params, + ) # [r, k * bs * l , d] retriever_output = retriever_output.reshape( - self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] + self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d + ) # [r * k, bs * l, d] # Chunks. pad = (ns - 1) % self.retro_chunk_length @@ -1182,18 +1402,20 @@ def retro_decoder_cross_attention(self, padded_chunks = torch.nn.functional.pad( attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), - 'constant', 0) - padded_chunked_output = padded_chunks \ - .reshape(l, self.retro_chunk_length, bs, d) \ - .permute(1, 2, 0, 3) + "constant", + 0, + ) + padded_chunked_output = padded_chunks.reshape( + l, self.retro_chunk_length, bs, d + ).permute(1, 2, 0, 3) padded_chunked_output = padded_chunked_output.reshape( - self.retro_chunk_length, bs * l, d).contiguous() + self.retro_chunk_length, bs * l, d + ).contiguous() # Encoder output. - attention_output, attention_bias = \ - self.inter_attention(padded_chunked_output, - None, - encoder_output=retriever_output) + attention_output, attention_bias = self.inter_attention( + padded_chunked_output, None, encoder_output=retriever_output + ) # Residual connection. 
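retro_decoder_cross_attention folds the decoder sequence into retrieval chunks before attending to the retrieved neighbors. The core reshape, shown in isolation for the even-multiple case (without the first-chunk padding handled above):

import torch

ns, bs, d, m = 12, 2, 8, 4          # sequence length, batch, hidden, chunk length
l = ns // m                          # number of chunks (ns assumed divisible by m here)
x = torch.randn(ns, bs, d)

# [ns, bs, d] -> [l, m, bs, d] -> [m, bs, l, d] -> [m, bs * l, d]
chunked = (
    x.reshape(l, m, bs, d)
    .permute(1, 2, 0, 3)
    .reshape(m, bs * l, d)
    .contiguous()
)
print(chunked.shape)   # torch.Size([4, 6, 8])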
if self.apply_residual_connection_post_layernorm: @@ -1205,17 +1427,27 @@ def retro_decoder_cross_attention(self, with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, - None if attention_bias is None else attention_bias.expand_as(attention_output), + ( + None + if attention_bias is None + else attention_bias.expand_as(attention_output) + ), torch.zeros_like(attention_output), - self.hidden_dropout) - layernorm_input = layernorm_input \ - .reshape(self.retro_chunk_length, bs, l, d) \ - .permute(2, 0, 1, 3) # [l, m, bs, d] - layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d) + self.hidden_dropout, + ) + layernorm_input = layernorm_input.reshape( + self.retro_chunk_length, bs, l, d + ).permute( + 2, 0, 1, 3 + ) # [l, m, bs, d] + layernorm_input = layernorm_input.reshape( + self.retro_chunk_length * l, bs, d + ) layernorm_input = torch.nn.functional.pad( - layernorm_input, - (0, 0, 0, 0, pad, 0), - 'constant', 0)[:ns] # [ns, b, d] + layernorm_input, (0, 0, 0, 0, pad, 0), "constant", 0 + )[ + :ns + ] # [ns, b, d] layernorm_input = layernorm_input + residual # Layer norm post the decoder attention @@ -1223,26 +1455,31 @@ def retro_decoder_cross_attention(self, return retriever_output, layernorm_input, layernorm_output - def forward(self, hidden_states, attention_mask=None, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None, - aggregated_moe_loss=None): + def forward( + self, + hidden_states, + attention_mask=None, + encoder_output=None, + enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + aggregated_moe_loss=None, + ): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output, attention_bias = \ - self.self_attention( - layernorm_output, - attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb) + attention_output, attention_bias = self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -1267,14 +1504,14 @@ def forward(self, hidden_states, attention_mask=None, attention_bias = attention_bias.expand_as(residual) with self.bias_dropout_add_exec_handler(): layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias, - residual, - self.hidden_dropout) + attention_output, attention_bias, residual, self.hidden_dropout + ) else: - out = torch.nn.functional.dropout(attention_output + attention_bias, - p=self.hidden_dropout, - training=self.training) + out = torch.nn.functional.dropout( + attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training, + ) layernorm_input = residual + self.drop_path(out) # Layer norm post the self attention. 
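When drop_path_rate > 0 the layer bypasses the fused bias-dropout-add path and instead adds the attention branch through DropPath (stochastic depth), as in the else branch above. DropPath itself, defined near the top of this file's diff, reduced to a function:

import torch


def drop_path(hidden_state, drop_prob: float, training: bool):
    """Stochastic depth on [s, b, h] activations: zero a sample's entire
    residual branch with probability drop_prob and rescale the survivors."""
    if drop_prob == 0.0 or not training:
        return hidden_state
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per batch element (dim 1), broadcast over s and h.
    shape = (1, hidden_state.shape[1]) + (1,) * (hidden_state.ndim - 2)
    mask = keep_prob + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
    mask.floor_()                        # binarize to 0 or 1
    return hidden_state.div(keep_prob) * mask


out = torch.randn(8, 4, 16)              # [s, b, h]
print(drop_path(out, drop_prob=0.1, training=True).shape)   # torch.Size([8, 4, 16])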
@@ -1284,23 +1521,25 @@ def forward(self, hidden_states, attention_mask=None, if self.layer_type == LayerType.encoder: pass elif self.layer_type == LayerType.decoder: - layernorm_input, layernorm_output = \ - self.default_decoder_cross_attention( - encoder_output, - enc_dec_attn_mask, - layernorm_input, - layernorm_output, - bias_dropout_add_func) + layernorm_input, layernorm_output = self.default_decoder_cross_attention( + encoder_output, + enc_dec_attn_mask, + layernorm_input, + layernorm_output, + bias_dropout_add_func, + ) elif self.layer_type == LayerType.retro_encoder: - layernorm_input, layernorm_output = \ - self.retro_encoder_cross_attention( - retriever_output, - layernorm_input, - layernorm_output, - bias_dropout_add_func) - elif self.layer_type in (LayerType.retro_decoder, - LayerType.retro_decoder_with_retriever): - retriever_output, layernorm_input, layernorm_output = \ + layernorm_input, layernorm_output = self.retro_encoder_cross_attention( + retriever_output, + layernorm_input, + layernorm_output, + bias_dropout_add_func, + ) + elif self.layer_type in ( + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + ): + retriever_output, layernorm_input, layernorm_output = ( self.retro_decoder_cross_attention( retriever_input, retriever_output, @@ -1308,14 +1547,19 @@ def forward(self, hidden_states, attention_mask=None, layernorm_input, layernorm_output, inference_params, - bias_dropout_add_func) + bias_dropout_add_func, + ) + ) else: - raise Exception("Unsupported layer type, '%s'." % - self.layer_type.name) + raise Exception("Unsupported layer type, '%s'." % self.layer_type.name) # MLP. - moe_loss = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) - mlp_bias = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) + moe_loss = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) if self.num_experts == 1: mlp_output, mlp_bias = self.mlp(layernorm_output) @@ -1337,10 +1581,8 @@ def forward(self, hidden_states, attention_mask=None, mlp_bias = mlp_bias.expand_as(residual) with self.bias_dropout_add_exec_handler(): output = bias_dropout_add_func( - mlp_output, - mlp_bias, - residual, - self.hidden_dropout) + mlp_output, mlp_bias, residual, self.hidden_dropout + ) # Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, @@ -1348,16 +1590,16 @@ def forward(self, hidden_states, attention_mask=None, # won't result in memory savings (like the data loader, or # p2p_communication), it serves to document the origin of this # 'view' tensor. 
- output = core.utils.make_viewless_tensor(inp = output, - requires_grad = output.requires_grad, - keep_graph = True) + output = core.utils.make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True + ) else: if mlp_bias is not None: mlp_output = mlp_output + mlp_bias - out = torch.nn.functional.dropout(mlp_output, - p=self.hidden_dropout, - training=self.training) + out = torch.nn.functional.dropout( + mlp_output, p=self.hidden_dropout, training=self.training + ) output = residual + self.drop_path(out) if self.layer_type == LayerType.retro_decoder_with_retriever: @@ -1386,25 +1628,47 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer): If no mask is provided, the module will query `self._args.attn_mask` for the mask and only return `super().forward(...)` """ - def __init__(self, config, - layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - drop_path_rate=0., num_experts=1, - input_aggregated_moe_loss=False, return_aggregated_moe_loss=False): + + def __init__( + self, + config, + layer_number, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.0, + num_experts=1, + input_aggregated_moe_loss=False, + return_aggregated_moe_loss=False, + ): self.input_aggregated_moe_loss = input_aggregated_moe_loss self.return_aggregated_moe_loss = return_aggregated_moe_loss - super().__init__(config, layer_number, layer_type, self_attn_mask_type, drop_path_rate, num_experts) + super().__init__( + config, + layer_number, + layer_type, + self_attn_mask_type, + drop_path_rate, + num_experts, + ) def forward(self, inputs, **kwargs): assert torch.is_tensor(inputs) or isinstance(inputs, tuple) - if not hasattr(self, '_args'): + if not hasattr(self, "_args"): self._args = get_args() - rotary_pos_emb = self._args.rotary_pos_emb if self._args.use_rotary_position_embeddings else None + rotary_pos_emb = ( + self._args.rotary_pos_emb + if self._args.use_rotary_position_embeddings + else None + ) if torch.is_tensor(inputs) or len(inputs) == 1: - assert not self.input_aggregated_moe_loss, f'Expecting an input tuple of size >= 2' + assert ( + not self.input_aggregated_moe_loss + ), f"Expecting an input tuple of size >= 2" # No attention mask forwarded, search for args.attn_mask hidden_states, attention_mask = inputs, self._args.attn_mask - output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb) + output, moe_loss = super().forward( + hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb + ) return (output, moe_loss) if self.return_aggregated_moe_loss else output elif len(inputs) in (2, 3): # Attention mask and aggregated_moe can both be activations. 
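# Convention sketch (annotation; the helper name is hypothetical):
# ParallelTransformerLayerPipe receives activations from the previous pipeline
# stage as a bare tensor or a tuple, and the branches in this hunk unpack them
# roughly as follows:
import torch

def unpack_pipe_inputs(inputs):
    if torch.is_tensor(inputs):
        return inputs, None, None                 # attention mask read from args
    if len(inputs) == 1:
        return inputs[0], None, None
    if len(inputs) == 2:
        return inputs[0], inputs[1], None         # (hidden_states, attn_mask)
    if len(inputs) == 3:
        return inputs[0], inputs[1], inputs[2]    # (..., aggregated_moe_loss)
    raise RuntimeError("Received more inputs than understood.")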
@@ -1417,22 +1681,28 @@ def forward(self, inputs, **kwargs): hidden_states, attention_mask = inputs[0], inputs[1] return_attention_mask = True else: - hidden_states, attention_mask, aggregated_moe_loss = inputs[0], inputs[1], inputs[2] + hidden_states, attention_mask, aggregated_moe_loss = ( + inputs[0], + inputs[1], + inputs[2], + ) # Forward aggregated_moe_loss to ParallelTransformerLayer for further accumulation if self.input_aggregated_moe_loss: - kwargs.update({'aggregated_moe_loss': aggregated_moe_loss}) + kwargs.update({"aggregated_moe_loss": aggregated_moe_loss}) - output, moe_loss = super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb) + output, moe_loss = super().forward( + hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb + ) - ret = (output, ) + ret = (output,) if return_attention_mask: - ret += (attention_mask, ) + ret += (attention_mask,) if self.return_aggregated_moe_loss: - ret += (moe_loss, ) + ret += (moe_loss,) return ret else: - raise RuntimeError('Received more inputs than understood.') + raise RuntimeError("Received more inputs than understood.") class NoopTransformerLayer(MegatronModule): @@ -1455,15 +1725,20 @@ def __init__(self, layer_number): super().__init__() self.layer_number = layer_number - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - inference_params=None): + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None, + ): return hidden_states.clone() def _get_num_layers(args, model_type, is_decoder=False): """Compute the number of transformer layers resident on the current rank.""" - is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) + is_encoder_and_decoder_model = model_type == ModelType.encoder_and_decoder if model_type == ModelType.retro_encoder: num_layers = args.retro_encoder_layers elif parallel_state.get_pipeline_model_parallel_world_size() > 1: @@ -1476,27 +1751,34 @@ def _get_num_layers(args, model_type, is_decoder=False): # the same whether or not a standalone embedding stage is used. 
num_ranks_in_encoder = ( args.pipeline_model_parallel_split_rank - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_split_rank + if args.standalone_embedding_stage + else args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = ( + args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + ) + assert args.encoder_num_layers % num_ranks_in_encoder == 0, ( + "encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)" + % (args.encoder_num_layers, num_ranks_in_encoder) + ) + assert args.decoder_num_layers % num_ranks_in_decoder == 0, ( + "decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)" + % (args.decoder_num_layers, num_ranks_in_decoder) ) - num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder - assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ - 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) - assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ - 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) if parallel_state.is_pipeline_stage_before_split(): num_layers = ( 0 if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 else - args.encoder_num_layers // num_ranks_in_encoder + and parallel_state.get_pipeline_model_parallel_rank() == 0 + else args.encoder_num_layers // num_ranks_in_encoder ) else: num_layers = args.decoder_num_layers // num_ranks_in_decoder else: assert args.num_layers == args.encoder_num_layers - assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ - 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + assert ( + args.num_layers % args.transformer_pipeline_model_parallel_size == 0 + ), "num_layers must be divisible by transformer_pipeline_model_parallel_size" # When a standalone embedding stage is used, all transformer layers # are divided among pipeline rank >= 1, while on pipeline rank 0, @@ -1505,8 +1787,8 @@ def _get_num_layers(args, model_type, is_decoder=False): num_layers = ( 0 if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 else - args.num_layers // args.transformer_pipeline_model_parallel_size + and parallel_state.get_pipeline_model_parallel_rank() == 0 + else args.num_layers // args.transformer_pipeline_model_parallel_size ) else: if not is_decoder: @@ -1516,14 +1798,15 @@ def _get_num_layers(args, model_type, is_decoder=False): return num_layers -def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, - layer_number): +def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, layer_number): args = get_args() if args.retro_add_retriever and layer_number in retro_layer_numbers: if model_type == ModelType.retro_decoder: - return LayerType.retro_decoder_with_retriever \ - if layer_number == retro_layer_numbers[0] \ - else LayerType.retro_decoder + return ( + LayerType.retro_decoder_with_retriever + if layer_number == retro_layer_numbers[0] + else LayerType.retro_decoder + ) elif model_type == ModelType.retro_encoder: return LayerType.retro_encoder else: @@ -1532,15 +1815,22 @@ def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, return default_layer_type -def get_num_experts_per_layer(num_experts: list, num_layers: int, expert_interval: int, 
offset: int = 0) -> list: - assert len(num_experts) == 1 or len(num_experts) == num_layers // expert_interval, \ - 'num_experts must be either a single value or a list of the same length as the number of MoE layers' +def get_num_experts_per_layer( + num_experts: list, num_layers: int, expert_interval: int, offset: int = 0 +) -> list: + assert ( + len(num_experts) == 1 or len(num_experts) == num_layers // expert_interval + ), "num_experts must be either a single value or a list of the same length as the number of MoE layers" if len(num_experts) == 1: num_experts = num_experts * (num_layers // expert_interval) experts_per_layer = [] for i in range(num_layers): layer_num = i + 1 + offset - n_e = num_experts[(layer_num-1) // expert_interval] if layer_num % expert_interval == 0 else 1 + n_e = ( + num_experts[(layer_num - 1) // expert_interval] + if layer_num % expert_interval == 0 + else 1 + ) experts_per_layer.append(n_e) return experts_per_layer @@ -1548,14 +1838,18 @@ def get_num_experts_per_layer(num_experts: list, num_layers: int, expert_interva class ParallelTransformer(MegatronModule): """Transformer class.""" - def __init__(self, config, - model_type, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - post_layer_norm=True, - pre_process=True, - post_process=True, - drop_path_rate=0.0, - num_experts=[1]): + def __init__( + self, + config, + model_type, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + post_layer_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0, + num_experts=[1], + ): super(ParallelTransformer, self).__init__() args = get_args() @@ -1578,14 +1872,15 @@ def __init__(self, config, self.recompute_granularity = config.recompute_granularity self.recompute_method = config.recompute_method self.recompute_num_layers = config.recompute_num_layers - self.distribute_saved_activations = \ + self.distribute_saved_activations = ( config.distribute_saved_activations and not config.sequence_parallel + ) self.sequence_parallel = config.sequence_parallel # Transformer Engine Init. self.transformer_engine_rope_available = False - if self.transformer_impl == 'transformer_engine': + if self.transformer_impl == "transformer_engine": global transformer_engine import transformer_engine from importlib.metadata import version @@ -1617,45 +1912,53 @@ def __init__(self, config, self.num_microbatches_in_previous_step = -1 self.microbatch_count = 0 - self.checkpoint_core_attention = config.recompute_granularity == 'selective' + self.checkpoint_core_attention = config.recompute_granularity == "selective" # Number of layers. - self.num_layers = _get_num_layers(args, model_type, - layer_type==LayerType.decoder) + self.num_layers = _get_num_layers( + args, model_type, layer_type == LayerType.decoder + ) self.drop_path_rates = [ - rate.item() for rate in - torch.linspace(0, self.drop_path_rate, config.num_layers)] + rate.item() + for rate in torch.linspace(0, self.drop_path_rate, config.num_layers) + ] self.retro_layer_numbers = None if model_type == ModelType.retro_decoder: retro_layer_start = 6 if config.num_layers <= 15 else 9 - self.retro_layer_numbers = \ - np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() + self.retro_layer_numbers = np.arange( + retro_layer_start, args.num_layers + 1, 3 + ).tolist() if model_type == ModelType.retro_encoder: self.retro_layer_numbers = [1] # Transformer layers. 
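# Worked example (annotation, toy values): get_num_experts_per_layer above
# places experts only on every `expert_interval`-th layer and leaves the rest
# dense. For 8 layers, num_experts=[16], expert_interval=2 it yields
# [1, 16, 1, 16, 1, 16, 1, 16]. Compact restatement of the same rule:
def experts_per_layer(num_experts, num_layers, expert_interval, offset=0):
    if len(num_experts) == 1:
        num_experts = num_experts * (num_layers // expert_interval)
    return [
        num_experts[(i + offset) // expert_interval]
        if (i + 1 + offset) % expert_interval == 0
        else 1
        for i in range(num_layers)
    ]

assert experts_per_layer([16], 8, 2) == [1, 16, 1, 16, 1, 16, 1, 16]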
if args.retro_add_retriever: - assert self.recompute_granularity != 'full', \ - "Full recompute not supported for Retro." - assert args.transformer_impl == 'local', \ - "Transformer engine does not support Retro layers." + assert ( + self.recompute_granularity != "full" + ), "Full recompute not supported for Retro." + assert ( + args.transformer_impl == "local" + ), "Transformer engine does not support Retro layers." + def build_layer(layer_number, n_e): - if args.transformer_impl == 'local': + if args.transformer_impl == "local": current_layer_type = _get_layer_type( - model_type, layer_type, self.retro_layer_numbers, - layer_number) + model_type, layer_type, self.retro_layer_numbers, layer_number + ) return ParallelTransformerLayer( config, layer_number, layer_type=current_layer_type, self_attn_mask_type=self_attn_mask_type, drop_path_rate=self.drop_path_rates[layer_number - 1], - num_experts=n_e) + num_experts=n_e, + ) else: - assert config.num_attention_heads == config.num_key_value_heads, \ - 'Transformer_engine does not support GQA' + assert ( + config.num_attention_heads == config.num_key_value_heads + ), "Transformer_engine does not support GQA" return transformer_engine.pytorch.TransformerLayer( config.hidden_size, config.ffn_hidden_size, @@ -1682,16 +1985,22 @@ def build_layer(layer_number, n_e): layer_type="encoder", drop_path_rate=self.drop_path_rates[layer_number - 1], set_parallel_mode=True, - fuse_qkv_params=True) + fuse_qkv_params=True, + ) if config.virtual_pipeline_model_parallel_size is not None: - assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ - 'num_layers_per_stage must be divisible by ' \ - 'virtual_pipeline_model_parallel_size' + assert ( + config.num_layers % config.virtual_pipeline_model_parallel_size == 0 + ), ( + "num_layers_per_stage must be divisible by " + "virtual_pipeline_model_parallel_size" + ) assert args.model_type != ModelType.encoder_and_decoder # Number of layers in each model chunk is the number of layers in the stage, # divided by the number of model chunks in a stage. - self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + self.num_layers = ( + self.num_layers // config.virtual_pipeline_model_parallel_size + ) # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of # layers to stages like (each list is a model chunk): # Stage 0: [0] [2] [4] [6] @@ -1701,12 +2010,14 @@ def build_layer(layer_number, n_e): # Stage 0: [0, 1] [4, 5] # Stage 1: [2, 3] [6, 7] offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * ( - config.num_layers // config.virtual_pipeline_model_parallel_size) + \ - (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) + config.num_layers // config.virtual_pipeline_model_parallel_size + ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) else: # Each stage gets a contiguous set of layers. 
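# Worked example (annotation) matching the stage/chunk comment above: with 8
# total layers, pipeline size 2, and 4 virtual chunks, each chunk holds one
# layer and the interleaved offset formula reproduces
# Stage 0: [0] [2] [4] [6] and Stage 1: [1] [3] [5] [7].
total_layers, pp_size, vpp_size = 8, 2, 4
layers_per_chunk = (total_layers // pp_size) // vpp_size   # 1

def chunk_layers(pp_rank, vpp_rank):
    offset = vpp_rank * (total_layers // vpp_size) + pp_rank * layers_per_chunk
    return list(range(offset, offset + layers_per_chunk))

assert [chunk_layers(0, v) for v in range(vpp_size)] == [[0], [2], [4], [6]]
assert [chunk_layers(1, v) for v in range(vpp_size)] == [[1], [3], [5], [7]]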
- if args.model_type == ModelType.encoder_and_decoder and \ - parallel_state.get_pipeline_model_parallel_world_size() > 1: + if ( + args.model_type == ModelType.encoder_and_decoder + and parallel_state.get_pipeline_model_parallel_world_size() > 1 + ): pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() if layer_type == LayerType.encoder: offset = pipeline_rank * self.num_layers @@ -1714,7 +2025,9 @@ def build_layer(layer_number, n_e): num_ranks_in_enc = args.pipeline_model_parallel_split_rank offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers else: - offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers + offset = ( + parallel_state.get_pipeline_model_parallel_rank() * self.num_layers + ) if self.num_layers == 0: # When a standalone embedding stage is used (e.g., @@ -1726,11 +2039,13 @@ def build_layer(layer_number, n_e): # this, we assign a 'no-op' layer on these ranks, which will # disconnect the input tensor from the output tensor. self.num_layers = 1 - self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) else: # Build the layers self.layers = [] - experts_per_layer = get_num_experts_per_layer(num_experts, self.num_layers, args.expert_interval, offset) + experts_per_layer = get_num_experts_per_layer( + num_experts, self.num_layers, args.expert_interval, offset + ) for i in range(self.num_layers): layer_num = i + 1 + offset n_e = experts_per_layer[i] @@ -1741,40 +2056,54 @@ def build_layer(layer_number, n_e): if model_type == ModelType.retro_encoder: for layer in self.layers: if layer.self_attention.use_flash_attn: - layer.self_attention.core_attention_flash.dropout_p = \ + layer.self_attention.core_attention_flash.dropout_p = ( torch.nn.Dropout(args.retro_encoder_attention_dropout) + ) else: - layer.self_attention.core_attention.attention_dropout.p =\ + layer.self_attention.core_attention.attention_dropout.p = ( args.retro_encoder_attention_dropout + ) layer.hidden_dropout = args.retro_encoder_hidden_dropout if self.post_process and self.post_layer_norm: # Final layer norm before output. 
- if args.normalization == 'layernorm': - if get_accelerator().device_name() == 'cuda': + if args.normalization == "layernorm": + if get_accelerator().device_name() == "cuda": self.final_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=config.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p, - mem_efficient_ln=args.mem_efficient_ln) + mem_efficient_ln=args.mem_efficient_ln, + ) else: self.final_layernorm = LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) else: - self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.final_layernorm = RMSNorm( + config.hidden_size, + config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel, + ) def _get_layer(self, layer_number): return self.layers[layer_number] - def _checkpointed_forward(self, hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - rotary_pos_emb, is_first_microbatch): + def _checkpointed_forward( + self, + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch, + ): args = get_args() """Forward method with activation checkpointing.""" + def custom(start, end): def custom_forward(*args, **kwargs): x_, *args = args @@ -1786,11 +2115,14 @@ def custom_forward(*args, **kwargs): x_, moe_loss = output else: x_ = output - moe_loss = torch.tensor(0.0, device=x_.device, dtype=x_.dtype, requires_grad=True) + moe_loss = torch.tensor( + 0.0, device=x_.device, dtype=x_.dtype, requires_grad=True + ) moe_losses.append(moe_loss) return (x_, *moe_losses) + return custom_forward - + if args.deepspeed and args.deepspeed_activation_checkpointing: moe_losses = [] # Make sure memory is freed. @@ -1798,9 +2130,18 @@ def custom_forward(*args, **kwargs): l = 0 while l < self.num_layers: hidden_states, *local_moe_losses = tensor_parallel.checkpoint( - custom(l, l + self.checkpoint_num_layers), False, - hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) + custom(l, l + self.checkpoint_num_layers), + False, + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + None, + None, + None, + None, + rotary_pos_emb, + ) moe_losses.extend(local_moe_losses) l += self.checkpoint_num_layers @@ -1808,66 +2149,105 @@ def custom_forward(*args, **kwargs): else: moe_losses = [] te_forward_kwargs = {} - if self.transformer_impl == 'transformer_engine': - te_forward_kwargs['is_first_microbatch'] = is_first_microbatch + if self.transformer_impl == "transformer_engine": + te_forward_kwargs["is_first_microbatch"] = is_first_microbatch if self.transformer_engine_rope_available: - te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + te_forward_kwargs["rotary_pos_emb"] = rotary_pos_emb - if self.recompute_method == 'uniform': + if self.recompute_method == "uniform": # Uniformly divide the total number of Transformer layers and # checkpoint the input activation of each divided chunk. # A method to further reduce memory usage reducing checkpoints. 
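# Illustration (annotation, toy values): the two recompute branches in this
# hunk differ only in which layer inputs are checkpointed. "uniform"
# checkpoints every chunk of recompute_num_layers layers; "block" checkpoints
# only the first recompute_num_layers layers and runs the rest normally.
num_layers, recompute_num_layers = 8, 2

uniform_chunks = [
    list(range(start, min(start + recompute_num_layers, num_layers)))
    for start in range(0, num_layers, recompute_num_layers)
]
# -> [[0, 1], [2, 3], [4, 5], [6, 7]]: each chunk's input activation is saved

block_checkpointed = [layer for layer in range(num_layers) if layer < recompute_num_layers]
# -> [0, 1]: only these layers recompute; layers 2..7 keep their activations
assert uniform_chunks[0] == block_checkpointed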
l = 0 while l < self.num_layers: - if self.transformer_impl == 'transformer_engine': - hidden_states, *local_moe_losses = transformer_engine.pytorch.distributed.checkpoint( - custom(l, l + self.recompute_num_layers), - self.distribute_saved_activations, - tensor_parallel.get_cuda_rng_tracker, - mpu.get_tensor_model_parallel_group(), - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) + if self.transformer_impl == "transformer_engine": + hidden_states, *local_moe_losses = ( + transformer_engine.pytorch.distributed.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + **te_forward_kwargs, + ) + ) else: hidden_states, *local_moe_losses = tensor_parallel.checkpoint( custom(l, l + self.recompute_num_layers), self.distribute_saved_activations, - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + None, + None, + None, + None, + rotary_pos_emb, + ) moe_losses.extend(local_moe_losses) l += self.recompute_num_layers - elif self.recompute_method == 'block': + elif self.recompute_method == "block": # Checkpoint the input activation of only a set number of individual # Transformer layers and skip the rest. # A method fully use the device memory removing redundant re-computation. for l in range(self.num_layers): if l < self.recompute_num_layers: - if self.transformer_impl == 'transformer_engine': - hidden_states, *local_moe_losses = transformer_engine.pytorch.distributed.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - tensor_parallel.get_cuda_rng_tracker, - mpu.get_tensor_model_parallel_group(), - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) + if self.transformer_impl == "transformer_engine": + hidden_states, *local_moe_losses = ( + transformer_engine.pytorch.distributed.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + **te_forward_kwargs, + ) + ) else: - hidden_states, *local_moe_losses = tensor_parallel.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) + hidden_states, *local_moe_losses = ( + tensor_parallel.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + None, + None, + None, + None, + rotary_pos_emb, + ) + ) else: - if self.transformer_impl == 'transformer_engine': + if self.transformer_impl == "transformer_engine": hidden_states, *local_moe_losses = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + **te_forward_kwargs, + ) else: hidden_states, *local_moe_losses = custom(l, l + 1)( - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) - + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + None, + None, + None, + None, + rotary_pos_emb, + ) + 
moe_losses.extend(local_moe_losses) else: raise ValueError("Invalid activation recompute method.") @@ -1883,19 +2263,25 @@ def set_input_tensor(self, input_tensor): forward_step_func""" self.input_tensor = input_tensor - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + ): # hidden_states: [s, b, h] # Checks. if inference_params: - assert self.recompute_granularity is None, \ - 'inference does not work with activation checkpointing' + assert ( + self.recompute_granularity is None + ), "inference does not work with activation checkpointing" # TODO: Below old DeepSpeed code are commented because it's unsure whether # it is still relevant. @@ -1950,64 +2336,77 @@ def forward(self, hidden_states, attention_mask, with rng_context: # The fp8_autocast context manager is a no-op when enabled=True # The if...else serves to short circuit name resolution for fp8_autocast - with transformer_engine.pytorch.fp8_autocast( - enabled=self.use_fp8, - fp8_recipe=self.fp8_recipe, - fp8_group=self.fp8_group - ) if self.use_fp8 else nullcontext(): + with ( + transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group, + ) + if self.use_fp8 + else nullcontext() + ): # Determine if the current iteration is first microbatch if self.num_microbatches_in_previous_step != get_num_microbatches(): - self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.microbatch_count = ( + 0 # Reset count on new batch size rampup interval + ) self.num_microbatches_in_previous_step = get_num_microbatches() - is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + is_first_microbatch = ( + self.microbatch_count % get_num_microbatches() == 0 + ) # Forward pass. 
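# Idiom sketch (annotation): the `with (fp8_autocast(...) if self.use_fp8 else
# nullcontext()):` pattern above enables a context manager conditionally
# without duplicating the body. Standalone version of the same idiom, with
# torch.autocast standing in for the fp8 context:
from contextlib import nullcontext

import torch

use_autocast = False
with torch.autocast("cpu", dtype=torch.bfloat16) if use_autocast else nullcontext():
    out = torch.ones(2, 2) @ torch.ones(2, 2)
assert out.shape == (2, 2)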
moe_losses = [] if self.checkpoint_activations: - hidden_states, moe_losses = self._checkpointed_forward(hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - rotary_pos_emb, - is_first_microbatch) - elif self.recompute_granularity == 'full': - hidden_states, moe_losses = self._checkpointed_forward(hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - rotary_pos_emb, - is_first_microbatch) + hidden_states, moe_losses = self._checkpointed_forward( + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch, + ) + elif self.recompute_granularity == "full": + hidden_states, moe_losses = self._checkpointed_forward( + hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch, + ) else: forward_kwargs = { - 'encoder_output': encoder_output, - 'enc_dec_attn_mask': enc_dec_attn_mask, - 'inference_params': inference_params, + "encoder_output": encoder_output, + "enc_dec_attn_mask": enc_dec_attn_mask, + "inference_params": inference_params, } - if self.transformer_impl == 'transformer_engine': - forward_kwargs['is_first_microbatch'] = is_first_microbatch - forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_impl == "transformer_engine": + forward_kwargs["is_first_microbatch"] = is_first_microbatch + forward_kwargs["checkpoint_core_attention"] = ( + self.checkpoint_core_attention + ) if self.transformer_engine_rope_available: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs["rotary_pos_emb"] = rotary_pos_emb else: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - forward_kwargs['retriever_input'] = retriever_input - forward_kwargs['retriever_output'] = retriever_output - forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + forward_kwargs["rotary_pos_emb"] = rotary_pos_emb + forward_kwargs["retriever_input"] = retriever_input + forward_kwargs["retriever_output"] = retriever_output + forward_kwargs["retriever_attn_mask"] = retriever_attn_mask for index in range(self.num_layers): layer = self._get_layer(index) hidden_states = layer( - hidden_states, - attention_mask, - **forward_kwargs) + hidden_states, attention_mask, **forward_kwargs + ) # First Retro decoder layer returns both hidden_states # and retriever_output. Make retriever_output available # to subsequence Retro layers. 
if isinstance(hidden_states, tuple): - assert (len(hidden_states) == 2 or len(hidden_states) == 3) + assert len(hidden_states) == 2 or len(hidden_states) == 3 if len(hidden_states) == 2: if not self.ds_inference: hidden_states, moe_loss = hidden_states @@ -2033,6 +2432,7 @@ def forward(self, hidden_states, attention_mask, return (hidden_states, *moe_losses) + class LMHeadPipe(MegatronModule): """ Arguments: @@ -2046,11 +2446,13 @@ class LMHeadPipe(MegatronModule): def __init__(self, hidden_size, vocab_size, config): args = get_args() super(LMHeadPipe, self).__init__() - self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=hidden_size, - output_size=vocab_size, - bias=False, - config=config, - init_method=config.init_method,) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=hidden_size, + output_size=vocab_size, + bias=False, + config=config, + init_method=config.init_method, + ) def forward(self, inputs, **kwargs): assert torch.is_tensor(inputs) or isinstance(inputs, tuple) @@ -2059,10 +2461,10 @@ def forward(self, inputs, **kwargs): else: hidden_states = inputs - if not hasattr(self, '_args'): + if not hasattr(self, "_args"): self._args = get_args() - if hasattr(self._args, 'attn_mask'): + if hasattr(self._args, "attn_mask"): attention_mask = None else: attention_mask = inputs[1] @@ -2070,7 +2472,7 @@ def forward(self, inputs, **kwargs): logits, _ = self.lm_head(hidden_states) # If cmd args has attn_mask, we don't forward it as an activation. - if hasattr(self._args, 'attn_mask'): + if hasattr(self._args, "attn_mask"): return logits else: return logits, attention_mask diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 48f2737a06..99145ff4f4 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -315,6 +315,15 @@ def optimizer_hook(p): weight_decay=args.weight_decay, momentum=args.sgd_momentum ) + elif str(args.optimizer).lower() == 'sophiag': + from .sophia import SophiaG + optimizer = SophiaG( + param_groups, + lr=args.lr, + betas=(args.sophiag_beta1, args.sophiag_beta2), + rho = args.sophiag_rho, + weight_decay=args.weight_decay + ) else: raise TypeError(f'{args.optimizer} optimizer is not supported.') if args.deepspeed: diff --git a/megatron/optimizer/sophia.py b/megatron/optimizer/sophia.py new file mode 100644 index 0000000000..4c4e074790 --- /dev/null +++ b/megatron/optimizer/sophia.py @@ -0,0 +1,202 @@ +import math +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing import List, Optional + + +#SOphiaG implementation from https://github.com/Liuhong99/Sophia/blob/main/sophia.py, copy pasted here because no pip and not sure about submodules + +class SophiaG(Optimizer): + def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04, + weight_decay=1e-1, *, maximize: bool = False, + capturable: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, rho=rho, + weight_decay=weight_decay, + maximize=maximize, capturable=capturable) + super(SophiaG, 
self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('capturable', False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group['betas'] + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + + @torch.no_grad() + def step(self, closure=None, bs=5120): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError('Hero does not support sparse gradients') + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) 
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + state_steps.append(state['step']) + hessian.append(state['hessian']) + + if self.defaults['capturable']: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag(params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group['rho'], + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + capturable=group['capturable']) + + return loss + +def sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool): + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + + func = _single_tensor_sophiag + + func(params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable) + +def _single_tensor_sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool): + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step_size = lr + step_size_neg = step_size.neg() + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) + else: + step_size_neg = - lr + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) diff --git a/megatron/profiler.py b/megatron/profiler.py new file mode 100644 index 0000000000..aeab144846 --- /dev/null +++ b/megatron/profiler.py @@ -0,0 +1,56 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
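# Usage sketch (annotation) for the SophiaG optimizer added above. The periodic
# update_hessian() cadence and the bs value are assumptions based on the
# upstream Sophia reference, not something this patch enforces on its own:
import torch

from megatron.optimizer.sophia import SophiaG

model = torch.nn.Linear(16, 4)
opt = SophiaG(model.parameters(), lr=1e-4, betas=(0.965, 0.99),
              rho=0.04, weight_decay=1e-1)

for step in range(20):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step(bs=8)                      # bs scales the clipped, sign-based update
    opt.zero_grad(set_to_none=True)
    if (step + 1) % 10 == 0:            # periodically refresh the Hessian EMA
        model(torch.randn(8, 16)).pow(2).mean().backward()
        opt.update_hessian()
        opt.zero_grad(set_to_none=True)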
+ +import torch + +on_step_begin = [] +on_step_end = [] + +def trigger(phase): + [f() for f in phase] + +def setup_profiler(args, device): + if args.profile is None: + return + + start_step, end_step = map(int, args.profile_steps.split(',')) + active_steps = end_step - start_step + 1 + cur_step = 0 + + def on_step_begin_fn(): + nonlocal cur_step + cur_step = cur_step + 1 + on_step_begin.append(on_step_begin_fn) + + def when(cond, clbk): + def fn(): + if cond(): + clbk() + return fn + + def is_start_step(): + return cur_step == start_step + + def is_end_step(): + return cur_step == end_step + + def is_capture_step(): + return cur_step >= start_step and cur_step <= end_step + + if args.profile.startswith('pt') and ( + args.profile_ranks is None or torch.distributed.get_rank() in args.profile_ranks + ): + schedule = torch.profiler.schedule(wait=0, warmup=0, active=active_steps, repeat=1) + activities = [torch.profiler.ProfilerActivity.CPU] + activities.extend([torch.profiler.ProfilerActivity.HPU] if device.startswith("hpu") else []) + activities.extend([torch.profiler.ProfilerActivity.CUDA] if device.startswith("cuda") else []) + full = args.profile == 'pt-full' + + profiler = torch.profiler.profile( + schedule=schedule, + activities=activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir, use_gzip=True), + with_stack=full) + + on_step_begin.append(when(is_start_step, profiler.start)) + on_step_end.append(when(is_capture_step, profiler.step)) + on_step_end.append(when(is_end_step, profiler.stop)) diff --git a/megatron/timers.py b/megatron/timers.py index 870ba8996f..5a9ba1f21d 100644 --- a/megatron/timers.py +++ b/megatron/timers.py @@ -39,11 +39,10 @@ def elapsed(self, reset=True, barrier=False): pass - class DummyTimer(TimerBase): def __init__(self): - super().__init__('dummy timer') + super().__init__("dummy timer") def start(self, barrier=False): return @@ -55,9 +54,7 @@ def reset(self): return def elapsed(self, reset=True, barrier=False): - raise Exception('dummy timer should not be used to ' - 'calculate elapsed time') - + raise Exception("dummy timer should not be used to " "calculate elapsed time") class Timer(TimerBase): @@ -78,37 +75,32 @@ def __init__(self, name): self._barrier_group = None self._start_time = time.time() - def set_barrier_group(self, barrier_group): self._barrier_group = barrier_group - def start(self, barrier=False): """Start the timer.""" - assert not self._started, 'timer has already been started' + assert not self._started, "timer has already been started" if barrier: torch.distributed.barrier(group=self._barrier_group) get_accelerator().synchronize() self._start_time = time.time() self._started = True - def stop(self, barrier=False): """Stop the timer.""" - assert self._started, 'timer is not started' + assert self._started, "timer is not started" if barrier: torch.distributed.barrier(group=self._barrier_group) get_accelerator().synchronize() - self._elapsed += (time.time() - self._start_time) + self._elapsed += time.time() - self._start_time self._started = False - def reset(self): """Reset timer.""" self._elapsed = 0.0 self._started = False - def elapsed(self, reset=True, barrier=False): """Calculate the elapsed time.""" _started = self._started @@ -126,7 +118,6 @@ def elapsed(self, reset=True, barrier=False): return _elapsed - class Timers: """Group of timers.""" @@ -138,24 +129,27 @@ def __init__(self, log_level, log_option): self._dummy_timer = DummyTimer() self._max_log_level = 2 - def __call__(self, name, 
log_level=None): # If the timer has already been set, then check if the log-level # is provided, it matches the one that the timer was created with. if name in self._timers: if log_level is not None: - assert log_level == self._log_levels[name], \ - 'input log level {} does not match already existing '\ - 'log level {} for {} timer'.format( - log_level, self._log_levels[name], name) + assert log_level == self._log_levels[name], ( + "input log level {} does not match already existing " + "log level {} for {} timer".format( + log_level, self._log_levels[name], name + ) + ) return self._timers[name] # If timer does not exist and no log level is provided, # set it to the max log level which is 2. if log_level is None: log_level = self._max_log_level - assert log_level <= self._max_log_level, \ - 'log level {} is larger than max supported log level {}'.format( - log_level, self._max_log_level) + assert ( + log_level <= self._max_log_level + ), "log level {} is larger than max supported log level {}".format( + log_level, self._max_log_level + ) # Now if the input log level is larger than the one set for # the timers class, just ignore it and return a dummy timer. if log_level > self._log_level: @@ -165,7 +159,6 @@ def __call__(self, name, log_level=None): self._log_levels[name] = log_level return self._timers[name] - def _get_elapsed_time_all_ranks(self, names, reset, barrier): """ Assumptions: @@ -191,34 +184,35 @@ def _get_elapsed_time_all_ranks(self, names, reset, barrier): # pytorch yet. It is simpler to deal with a single tensor # and since we are only gathering a small amount of data, # it should be ok to use all-gather instead of gather. - rank_name_to_time = torch.zeros((world_size, len(names)), - dtype=torch.float, - device=get_accelerator().current_device_name()) + rank_name_to_time = torch.zeros( + (world_size, len(names)), + dtype=torch.float, + device=get_accelerator().current_device_name(), + ) for i, name in enumerate(names): if name in self._timers: # Here we don't need to pass the barrier flag as all # the processes are already in sync. This avoids the # issue of different timers having different barrier # groups inside their class. - rank_name_to_time[rank, i] = self._timers[name].elapsed( - reset=reset) + rank_name_to_time[rank, i] = self._timers[name].elapsed(reset=reset) # See the note above for why we are not using gather. 
- if version.parse(torch.__version__) >= version.parse('1.13'): - torch.distributed.all_gather_into_tensor(rank_name_to_time.view(-1), - rank_name_to_time[rank, :].view(-1)) + if version.parse(torch.__version__) >= version.parse("1.13"): + torch.distributed.all_gather_into_tensor( + rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1) + ) else: - torch.distributed._all_gather_base(rank_name_to_time.view(-1), - rank_name_to_time[rank, :].view(-1)) + torch.distributed._all_gather_base( + rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1) + ) return rank_name_to_time - def _get_global_min_max_time(self, names, reset, barrier, normalizer): """Report only min and max times across all ranks.""" - rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, - barrier) + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier) name_to_min_max_time = {} for i, name in enumerate(names): rank_to_time = rank_name_to_time[:, i] @@ -228,34 +222,36 @@ def _get_global_min_max_time(self, names, reset, barrier, normalizer): if rank_to_time.numel() > 0: name_to_min_max_time[name] = ( rank_to_time.min().item() / normalizer, - rank_to_time.max().item() / normalizer) + rank_to_time.max().item() / normalizer, + ) return name_to_min_max_time - - def _get_global_min_max_time_string(self, names, reset, barrier, - normalizer, max_only): + def _get_global_min_max_time_string( + self, names, reset, barrier, normalizer, max_only + ): name_to_min_max_time = self._get_global_min_max_time( - names, reset, barrier, normalizer) + names, reset, barrier, normalizer + ) if not name_to_min_max_time: return None - output_string = '(min, max) time across ranks (ms):' + output_string = "(min, max) time across ranks (ms):" for name in name_to_min_max_time: min_time, max_time = name_to_min_max_time[name] if max_only: - output_string += '\n {}: {:.2f}'.format( - (name+' ').ljust(48, '.'), max_time) + output_string += "\n {}: {:.2f}".format( + (name + " ").ljust(48, "."), max_time + ) else: - output_string += '\n {}: ({:.2f}, {:.2f})'.format( - (name+' ').ljust(48, '.'), min_time, max_time) + output_string += "\n {}: ({:.2f}, {:.2f})".format( + (name + " ").ljust(48, "."), min_time, max_time + ) return output_string - def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): """Report times across all ranks.""" - rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, - barrier) + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier) - output_string = 'times across ranks (ms):' + output_string = "times across ranks (ms):" no_reported_timing = True for i, name in enumerate(names): not_yet_found = True @@ -264,32 +260,32 @@ def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): no_reported_timing = False if not_yet_found: not_yet_found = False - output_string += '\n {}:'.format(name) - output_string += '\n rank {:2d}: {:.2f}'.format( - rank, rank_name_to_time[rank, i] / normalizer) + output_string += "\n {}:".format(name) + output_string += "\n rank {:2d}: {:.2f}".format( + rank, rank_name_to_time[rank, i] / normalizer + ) if no_reported_timing: return None return output_string - def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): """Log a group of timers.""" # Print. 
assert normalizer > 0.0 - if self._log_option in ['max', 'minmax']: + if self._log_option in ["max", "minmax"]: max_only = False - if self._log_option == 'max': + if self._log_option == "max": max_only = True output_string = self._get_global_min_max_time_string( - names, reset, barrier, normalizer/1000.0, max_only) - elif self._log_option == 'all': - output_string = self._get_all_ranks_time_string(names, - reset, barrier, - normalizer/1000.0) + names, reset, barrier, normalizer / 1000.0, max_only + ) + elif self._log_option == "all": + output_string = self._get_all_ranks_time_string( + names, reset, barrier, normalizer / 1000.0 + ) else: - raise Exception('unknown timing log option {}'.format( - self._log_option)) + raise Exception("unknown timing log option {}".format(self._log_option)) # If no input rank is provided, log on last rank. if rank is None: @@ -297,15 +293,14 @@ def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): if rank == torch.distributed.get_rank() and output_string is not None: print(output_string, flush=True) - def write( - self, - names: list[str], - writer: Writer, - iteration: int, - normalizer: float = 1.0, - reset: bool = False, - barrier: bool = False + self, + names: list[str], + writer: Writer, + iteration: int, + normalizer: float = 1.0, + reset: bool = False, + barrier: bool = False, ): """Write timers to a tensorboard writer Note that we only report maximum time across ranks to tensorboard. @@ -315,17 +310,22 @@ def write( # polutes the runs list, so we just add each as a scalar assert normalizer > 0.0 name_to_min_max_time = self._get_global_min_max_time( - names, reset, barrier, normalizer) + names, reset, barrier, normalizer + ) + # <<<<<<< HEAD timer_data = { - 'timers/iteration': iteration, + "timers/iteration": iteration, **{ - f'timers/{k}-time': name_to_min_max_time[k][1] + f"timers/{k}-time": name_to_min_max_time[k][1] for k in name_to_min_max_time - } + }, } - if wandb is not None and getattr(wandb, 'run', None) is not None: + if wandb is not None and getattr(wandb, "run", None) is not None: wandb.log(timer_data, commit=False) - if writer is not None: + # ======= + # if writer.is_enabled(): + # >>>>>>> 0d6e3793a1fc06eded9764ef15ad12bcc0281101 + if writer is not None: # and writer.is_enabled(): for name in name_to_min_max_time: _, max_time = name_to_min_max_time[name] - writer.add_scalar(f'{name}-time', max_time, iteration) + writer.add_scalar(f"{name}-time", max_time, iteration) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 023c2f756c..92853fb30d 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Megatron tokenizers.""" @@ -45,11 +46,14 @@ def build_tokenizer(args): tokenizer = _NullTokenizer(args.vocab_size) elif args.tokenizer_type == 'HFTokenizer': assert args.tokenizer_model is not None - tokenizer = _HFTokenizer(args.tokenizer_model) + # tokenizer = _HFTokenizer(args.tokenizer_model) + tokenizer = _HFTokenizer(args.tokenizer_model, + args.seq_length, + args.trust_remote_code) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) - + # Add vocab size. 
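# Sketch (annotation) of the new _HFTokenizer wiring: build_tokenizer above now
# forwards args.seq_length and args.trust_remote_code, and the constructor
# changes further below amount to roughly these Hugging Face calls ("gpt2" and
# 2048 are placeholder values, not defaults taken from the patch):
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "gpt2", padding_side="right", trust_remote_code=False, use_fast=False
)
special_tokens = {}
if tok.pad_token is None:
    special_tokens["pad_token"] = "[PAD]"   # DEFAULT_PAD_TOKEN in the patch
tok.add_special_tokens(special_tokens)
tok.model_max_length = 2048                 # set from args.seq_length in the patch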
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) @@ -595,10 +599,30 @@ def additional_special_tokens_ids(self): class _HFTokenizer(AbstractTokenizer): """HF Tokenizer""" - def __init__(self, tokenizer_name_or_path): + def __init__(self, tokenizer_name_or_path, max_seq_len, trust_remote_code): name = tokenizer_name_or_path super().__init__(name) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, + padding_side="right", + trust_remote_code=trust_remote_code, + use_fast=False) + DEFAULT_PAD_TOKEN = "[PAD]" + DEFAULT_EOS_TOKEN = "" + DEFAULT_BOS_TOKEN = "" + DEFAULT_UNK_TOKEN = "" + special_tokens_dict = dict() + if self.tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN + if self.tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN + if self.tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN + if self.tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN + self.tokenizer.add_special_tokens(special_tokens_dict) + # if self.tokenizer.pad_token == None: + # self.tokenizer.pad_token= "[PAD]" + self.tokenizer.model_max_length = max_seq_len self.encoder = self.tokenizer.get_vocab() self.decoder = {v: k for k, v in self.encoder.items()} diff --git a/megatron/training.py b/megatron/training.py index 9289c64ec8..8ffac6cb9c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1,66 +1,77 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain utilities.""" +import time +# The earliest we can measure the start time. +_TRAIN_START_TIME = time.time() + +from collections import OrderedDict from datetime import datetime +import json +import logging import math -import sys import os +import sys import time -import json -# noqa: E402 -# The earliest we can measure the start time. 
-_TRAIN_START_TIME = time.time() +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.compression.compress import init_compression, redundancy_clean +from deepspeed.runtime.data_pipeline.data_routing.helper import ( + convert_to_random_ltd, +) +import ezpz as ez import torch import torch.distributed as tdist -from collections import OrderedDict from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP -from megatron import get_args -from megatron import get_signal_handler -from megatron import get_timers -from megatron import get_tensorboard_writer -from megatron import get_current_global_batch_size -from megatron import get_num_microbatches -from megatron import is_last_rank -from megatron import update_num_microbatches +from megatron import ( + get_args, + get_current_global_batch_size, + get_num_microbatches, + get_signal_handler, + get_tensorboard_writer, + get_timers, + is_last_rank, + update_num_microbatches, +) +from megatron.arguments import core_transformer_config_from_args +from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.core import mpu, tensor_parallel -# from megatron import print_rank_0, is_rank_0 -# from megatron import print_rank_last -from megatron.checkpointing import load_checkpoint -from megatron.checkpointing import save_checkpoint -from megatron.model import Float16Module -from megatron.model import GPTModel from megatron.core.enums import ModelType +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.initialize import ( + initialize_megatron, + set_jit_fusion_options, + write_args_to_tensorboard, +) +from megatron.model import Float16Module, GPTModel +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model.transformer import ParallelTransformerLayer +from megatron.model.vision.knn_monitor import compute_feature_bank from megatron.optimizer import get_megatron_optimizer -from megatron.initialize import initialize_megatron -from megatron.initialize import write_args_to_tensorboard -from megatron.initialize import set_jit_fusion_options from megatron.optimizer_param_scheduler import OptimizerParamScheduler -from megatron.model import DistributedDataParallel as LocalDDP -from megatron.utils import check_adlr_autoresume_termination -from megatron.utils import unwrap_model -from megatron.data.data_samplers import build_pretraining_data_loader -from megatron.utils import calc_params_l2_norm -from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.training_log import training_log from megatron.utils import ( - report_memory, - throughput_calculator, + PerfTrace, + Profile, + calc_params_l2_norm, + check_adlr_autoresume_termination, checkpoint_throughput_calculator, + found_kill_switch, + unwrap_model, update_rotary_pos_emb, ) -from megatron.model.vision.knn_monitor import compute_feature_bank -from megatron.arguments import core_transformer_config_from_args -from megatron.utils import PerfTrace, Profile +from megatron.profiler import ( + setup_profiler, + trigger, + on_step_begin, + on_step_end, +) -import deepspeed -from deepspeed.accelerator import get_accelerator -from deepspeed.compression.compress import init_compression, redundancy_clean -from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd -from megatron.model.transformer import ParallelTransformerLayer -import ezpz as ez -import logging dlp = 
Profile("TRAINING") @@ -76,11 +87,6 @@ LOG_LEVEL: str = str(os.environ.get("LOG_LEVEL", "INFO")).upper() log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL") -try: - import wandb -except (ImportError, ModuleNotFoundError): - wandb = None - def print_datetime(string): """Note that this call will sync across all ranks.""" @@ -89,37 +95,6 @@ def print_datetime(string): log.info("[" + string + "] datetime={} ".format(time_str)) -def num_floating_point_operations(args, batch_size): - # Group Query Attention. - # if not args.group_query_attention: - if not args.num_key_value_heads: - args.num_key_value_heads = args.num_attention_heads - # args.num_query_groups = args.num_attention_heads - # MoE. - # num_experts_routed_to = 1 if args.num_experts is None else args.moe_router_topk - num_experts_routed_to = 1 if args.num_experts is None else args.topk - gated_linear_multiplier = 3 / 2 if args.swiglu else 1 - return ( - 12 - * batch_size - * args.seq_length - * args.num_layers - * args.hidden_size - * args.hidden_size - * ( - 1 - + ( - (args.ffn_hidden_size / args.hidden_size) - * num_experts_routed_to - * gated_linear_multiplier - ) - + (args.num_key_value_heads / args.num_attention_heads) - + (args.seq_length / args.hidden_size) - + (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size)) - ) - ) - - """ Since v0.9.0, deepspeed.initialize() has forbidden simultaneous setting of args.deepspeed_config (Path) and ds_config dict. So, we use ds_config dict which is the more flexible option @@ -143,15 +118,15 @@ def _create_ds_config_dict(): @dlp.log def pretrain( - train_valid_test_dataset_provider, - model_provider, - model_type, - forward_step_func, - process_non_loss_data_func=None, - extra_args_provider=None, - args_defaults={}, - data_post_process=None, - external_args={}, + train_valid_test_dataset_provider, + model_provider, + model_type, + forward_step_func, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults={}, + data_post_process=None, + external_args={}, ) -> list[torch.nn.Module]: """Main training program. @@ -190,9 +165,15 @@ def pretrain( args_defaults=args_defaults, external_args=external_args, ) + args = get_args() + assert args is not None + if found_kill_switch(): + print_datetime(f"Detected kill switch at {args.kill_switch_file}. Exiting") + sys.exit() + # Set pytorch JIT layer fusion options and warmup JIT functions. # if get_accelerator().device_name() == "cuda": - if DEVICE_TYPE == 'cuda' and torch.cuda.is_available(): + if DEVICE_TYPE == "cuda" and torch.cuda.is_available(): set_jit_fusion_options() # Adjust the startup time so it reflects the largest value. 
@@ -204,7 +185,9 @@ def pretrain( f"time to finish initialize_megatron: {time.time() - _TRAIN_START_TIME} seconds" ) # start_time_tensor = DEVICE.DoubleTensor([_TRAIN_START_TIME]) - start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device=DEVICE_TYPE) + start_time_tensor = torch.tensor( + [_TRAIN_START_TIME], dtype=torch.double, device=DEVICE_TYPE + ) tdist.all_reduce(start_time_tensor, op=tdist.ReduceOp.MIN) log.info(f"allreduce call time: {time.time()-before_allreduce} seconds") _TRAIN_START_TIME = start_time_tensor.item() @@ -214,8 +197,6 @@ def pretrain( ) ) print_datetime("after megatron is initialized") - args = get_args() - assert args is not None if os.getenv("DLIO_PROFILER_DATASET_DIR") is not None: extra_trace_path = os.environ["DLIO_PROFILER_DATASET_DIR"] else: @@ -429,6 +410,8 @@ def get_model( ): """Build the model.""" args = get_args() + accelerator = get_accelerator() + assert accelerator is not None assert args is not None args.model_type = model_type @@ -535,7 +518,7 @@ def get_model( if wrap_with_ddp: if args.DDP_impl == "torch": - i = get_accelerator().current_device() + i = accelerator.current_device() model = [ torchDDP( model_module, @@ -561,9 +544,8 @@ def get_model( model_module.broadcast_params() else: raise NotImplementedError( - "Unknown DDP implementation specified: " "{}. Exiting.".format( - args.DDP_impl - ) + "Unknown DDP implementation specified: " + "{}. Exiting.".format(args.DDP_impl) ) return model @@ -639,18 +621,13 @@ def load_model_weights_only(model_provider_func): ) assert not isinstance(model, deepspeed.PipelineEngine), ( - "Weight loading only mode is not supported in " - "pipeline parallelism yet." + "Weight loading only mode is not supported in " "pipeline parallelism yet." ) model = [model] print_datetime("before load checkpoint") if args.load is not None: - iteration = load_checkpoint( - model, - optimizer, - lr_scheduler, - strict=True, - load_only_weights=True + _ = load_checkpoint( + model, optimizer, lr_scheduler, strict=True, load_only_weights=True ) print_datetime("after load checkpoint weights") return model, optimizer, lr_scheduler @@ -659,14 +636,14 @@ def load_model_weights_only(model_provider_func): @dlp.log @ez.dist.timeitlogit(rank=RANK) def setup_model_and_optimizer( - model_provider_func, - model_type, - no_wd_decay_cond=None, - scale_lr_cond=None, - lr_mult=1.0, - teacher=False, - data_post_process=None, - build_train_valid_test_datasets_provider=None, + model_provider_func, + model_type, + no_wd_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0, + teacher=False, + data_post_process=None, + build_train_valid_test_datasets_provider=None, ): """Setup model and optimizer.""" args = get_args() @@ -714,8 +691,8 @@ def setup_model_and_optimizer( log.info("DeepSpeed is enabled.") # pp = mpu.get_pipeline_model_parallel_world_size() if ( - args.data_efficiency_curriculum_learning - and build_train_valid_test_datasets_provider is not None + args.data_efficiency_curriculum_learning + and build_train_valid_test_datasets_provider is not None ): log.info( "Caught 'args.data_efficiency_curriculum_learning' " @@ -725,16 +702,14 @@ def setup_model_and_optimizer( # Only need to build dataset on tp rank 0 since Megatron has the # broadcast_data() function that broadcast data from tp rank 0. if mpu.get_tensor_model_parallel_rank() == 0: - log.info( - f"Caught 'mpu.get_tensor_model_parallel_rank() == 0'" - ) + log.info("Caught 'mpu.get_tensor_model_parallel_rank() == 0'") # Number of train/valid/test samples. 
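# If `--train-samples` is given it is used directly (and `update_train_iters`
# derives the iteration count from it); otherwise the sample budget follows
# from the iteration budget: train_samples = train_iters * global_batch_size,
# e.g. 100_000 iterations at a global batch size of 1024 is 102_400_000 samples.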
if args.train_samples: train_samples = args.train_samples update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size - log.info(f'{train_samples=}') + log.info(f"{train_samples=}") # eval_iters and test_iters here are not actually used, only for # satisfying the input of build_train_valid_test_datasets_provider. # We only need to build the training data here. And we follow @@ -774,9 +749,7 @@ def setup_model_and_optimizer( ) tds0 = time.time() if os.environ.get("PYINSTRUMENT_PROFILER", None): - profiler = ez.profile.get_context_manager( - rank=RANK, outdir=args.save - ) + profiler = ez.profile.get_context_manager(rank=RANK, outdir=args.save) else: profiler = Profile("deepspeed.initialize") log.info("Calling 'deepspeed.initialize'...") @@ -832,6 +805,7 @@ def setup_model_and_optimizer( log.info("Initializing ICT from pretrained BERT model") unwrapped_model[0].init_state_dict_from_bert() if args.fp16: + assert optimizer is not None optimizer.reload_model_params() # random-LTD requires converting transformer layers if args.random_ltd: @@ -846,10 +820,11 @@ def train_step( """Single training step.""" args = get_args() timers = get_timers() - - assert args is not None and timers is not None + accelerator = get_accelerator() + assert args is not None and timers is not None and accelerator is not None + grad_norm = None + num_zeros_in_grad = None if args.deepspeed and args.ds_pipeline_enabled: - skipped_iter = 0 num_zeros_in_grad = 0 assert isinstance(model[0], deepspeed.PipelineEngine) loss = model[0].train_batch(data_iter=data_iterator) @@ -861,6 +836,8 @@ def train_step( if additional_losses is not None: loss_dict.update(additional_losses) grad_norm = model[0].get_global_grad_norm() + update_successful = model[0].was_step_applied() + skipped_iter = 0 if update_successful else 1 return loss_dict, skipped_iter, grad_norm, num_zeros_in_grad # Set grad to zero. @@ -883,11 +860,13 @@ def train_step( if args.timing_log_level < 2: config.timers = None + num_microbatches = get_num_microbatches() + assert num_microbatches is not None losses_reduced = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -902,8 +881,8 @@ def train_step( args.teacher_forward = False # Empty unused memory. - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() + if args.empty_unused_memory_level >= 1 and accelerator is not None: + accelerator.empty_cache() # Reduce gradients. if not args.deepspeed: @@ -940,8 +919,18 @@ def train_step( # Update learning rate. if args.deepspeed: - skipped_iter = 0 + skipped_iter = 0 if update_successful else 1 grad_norm = model[0].get_global_grad_norm() + # Empty unused memory. + if args.empty_unused_memory_level >= 2 and accelerator is not None: + accelerator.empty_cache() + + # XXX: [saforem2]: ---------------------------------------------------- + # Is `num_zeros_in_grad` worth calculating (/ implementing) ?? 
+ # the `Megatron`-specific implementation is at: + # [megatron.optimizer.clip_grads.count_zeros_fp32](./optimizer/clip_grads.py) + # For now, explicitly set to None + # --------------------------------------------------------------------- num_zeros_in_grad = None loss_reduced = {} for key in losses_reduced[0]: @@ -950,650 +939,29 @@ def train_step( losses_reduced_for_key ) return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad - else: - if update_successful: - increment = ( - get_num_microbatches() * args.micro_batch_size * args.data_parallel_size - ) - opt_param_scheduler.step(increment=increment) - skipped_iter = 0 - else: - skipped_iter = 1 - - # Empty unused memory. - if args.empty_unused_memory_level >= 2: - torch.cuda.empty_cache() - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Average loss across microbatches. - loss_reduced = {} - for key in losses_reduced[0]: - losses_reduced_for_key = [x[key] for x in losses_reduced] - loss_reduced[key] = sum(losses_reduced_for_key) / len( - losses_reduced_for_key - ) - return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad - return {}, skipped_iter, grad_norm, num_zeros_in_grad - - -@dlp.log -def training_log( - loss_dict, - total_loss_dict, - learning_rate, - iteration, - loss_scale, - report_memory_flag, - skipped_iter, - grad_norm, - params_norm, - num_zeros_in_grad, - model=None, - optimizer=None, -): - """Log training information such as losses, timing, ....""" - args = get_args() - timers = get_timers() - writer = get_tensorboard_writer() - assert args is not None and timers is not None - wandb_metrics = {} - # Advanced, skipped, and Nan iterations. - advanced_iters_key = "advanced iterations" - skipped_iters_key = "skipped iterations" - nan_iters_key = "nan iterations" - # Advanced iterations. - if not skipped_iter: - total_loss_dict[advanced_iters_key] = ( - total_loss_dict.get(advanced_iters_key, 0) + 1 + if update_successful: + increment = ( + get_num_microbatches() * args.micro_batch_size * args.data_parallel_size ) + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 else: - if advanced_iters_key not in total_loss_dict: - total_loss_dict[advanced_iters_key] = 0 - # Skipped iterations. - total_loss_dict[skipped_iters_key] = ( - total_loss_dict.get(skipped_iters_key, 0) + skipped_iter - ) - # Update losses and set nan iterations - got_nan = False - for key in loss_dict: - if not skipped_iter: - total_loss_dict[key] = ( - total_loss_dict.get(key, get_accelerator().FloatTensor([0.0])) - + loss_dict[key] - ) - else: - value = loss_dict[key].float().sum().item() - is_nan = value == float("inf") or value == -float("inf") or value != value - got_nan = got_nan or is_nan - total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int( - got_nan - ) - - # Logging. 
- timers_to_log = [ - "forward-backward", - "forward-compute", - "backward-compute", - "batch-generator", - "forward-recv", - "forward-send", - "backward-recv", - "backward-send", - "forward-send-forward-recv", - "forward-send-backward-recv", - "backward-send-forward-recv", - "backward-send-backward-recv", - "forward-backward-send-forward-backward-recv", - "layernorm-grads-all-reduce", - "embedding-grads-all-reduce", - "grads-all-reduce", - "grads-reduce-scatter", - "params-all-gather", - "optimizer-copy-to-main-grad", - "optimizer-unscale-and-check-inf", - "optimizer-clip-main-grad", - "optimizer-count-zeros", - "optimizer-inner-step", - "optimizer-copy-main-to-model-params", - "optimizer", - ] + skipped_iter = 1 - # Calculate batch size. - batch_size = ( - args.micro_batch_size * args.data_parallel_size * get_num_microbatches() - ) - total_iterations = ( - total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] - ) - - # Tensorboard values. - # Timer requires all the ranks to call. - if args.log_timers_to_tensorboard and ( - iteration % args.tensorboard_log_interval == 0 - ): - timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) - if writer and (iteration % args.tensorboard_log_interval == 0): - writer.add_scalar( - "steps-vs-samples/y=steps,x=samples", iteration, args.consumed_train_samples - ) - writer.add_scalar( - "steps-vs-samples/y=samples,x=steps", args.consumed_train_samples, iteration - ) - writer.add_scalar( - "steps-vs-tokens/y=steps,x=tokens", iteration, args.consumed_train_tokens - ) - writer.add_scalar( - "steps-vs-tokens/y=tokens,x=steps", args.consumed_train_tokens, iteration - ) - if args.log_learning_rate_to_tensorboard: - wandb_metrics |= { - "learning-rate/iteration": iteration, - "learning-rate/learning-rate": learning_rate, - } - writer.add_scalar("learning-rate/learning-rate", learning_rate, iteration) - writer.add_scalar( - "learning-rate/learning-rate vs samples", - learning_rate, - args.consumed_train_samples, - ) - writer.add_scalar( - "learning-rate/learning-rate vs tokens", - learning_rate, - args.consumed_train_tokens, - ) - if args.log_batch_size_to_tensorboard: - writer.add_scalar("batch-size/batch-size", batch_size, iteration) - writer.add_scalar( - "batch-size/batch-size vs samples", - batch_size, - args.consumed_train_samples, - ) - writer.add_scalar( - "batch-size/batch-size vs tokens", - batch_size, - args.consumed_train_tokens, - ) - wandb_metrics |= { - "lm-loss-training/iteration": iteration, - "lm-loss-training/consumed_train_tokens": args.consumed_train_tokens, - } - for key in loss_dict: - wandb_metrics |= {f"lm-loss-training/{key}": loss_dict[key]} - writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration) - writer.add_scalar( - f"lm-loss-training/{key}" + " vs samples", - loss_dict[key], - args.consumed_train_samples, - ) - writer.add_scalar( - f"lm-loss-training/{key}" + " vs tokens", - loss_dict[key], - args.consumed_train_tokens, - ) - if args.fp16 and loss_scale and args.log_loss_scale_to_tensorboard: - writer.add_scalar("loss-scale/loss-scale", loss_scale, iteration) - writer.add_scalar( - "loss-scale/loss-scale vs samples", - loss_scale, - args.consumed_train_samples, - ) - writer.add_scalar( - "loss-scale/loss-scale vs tokens", - loss_scale, - args.consumed_train_tokens, - ) - if args.log_world_size_to_tensorboard: - writer.add_scalar("world-size/world-size", args.world_size, iteration) - writer.add_scalar( - "world-size/world-size vs samples", - args.world_size, - 
args.consumed_train_samples, - ) - writer.add_scalar( - "world-size/world-size vs tokens", - args.world_size, - args.consumed_train_tokens, - ) - if grad_norm is not None: - wandb_metrics |= {"training/grad-norm": grad_norm} - writer.add_scalar("grad-norm/grad-norm", grad_norm, iteration) - writer.add_scalar( - "grad-norm/grad-norm vs samples", grad_norm, args.consumed_train_samples - ) - writer.add_scalar( - "grad-norm/grad-norm vs tokens", grad_norm, args.consumed_train_tokens - ) - if num_zeros_in_grad is not None: - wandb_metrics |= {"training/num-zeros": num_zeros_in_grad} - writer.add_scalar("num-zeros/num-zeros", num_zeros_in_grad, iteration) - writer.add_scalar( - "num-zeros/num-zeros vs samples", - num_zeros_in_grad, - args.consumed_train_samples, - ) - writer.add_scalar( - "num-zeros/num-zeros vs tokens", - num_zeros_in_grad, - args.consumed_train_tokens, - ) - if params_norm is not None: - wandb_metrics |= {"training/params-norm": params_norm} - writer.add_scalar("params-norm/params-norm", params_norm, iteration) - writer.add_scalar( - "params-norm/params-norm vs samples", - params_norm, - args.consumed_train_samples, - ) - writer.add_scalar( - "params-norm/params-norm vs tokens", - params_norm, - args.consumed_train_tokens, - ) - if hasattr(args, "actual_seq_length"): - writer.add_scalar( - "seqlen/actual_seq_length", args.actual_seq_length, iteration - ) - writer.add_scalar( - "seqlen/actual_seq_length vs samples", - args.actual_seq_length, - args.consumed_train_samples, - ) - writer.add_scalar( - "seqlen/actual_seq_length vs tokens", - args.actual_seq_length, - args.consumed_train_tokens, - ) - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - writer.add_scalar( - "seqlen/curriculum_seqlen", args.curriculum_seqlen, iteration - ) - writer.add_scalar( - "seqlen/curriculum_seqlen vs samples", - args.curriculum_seqlen, - args.consumed_train_samples, - ) - writer.add_scalar( - "seqlen/curriculum_seqlen vs tokens", - args.curriculum_seqlen, - args.consumed_train_tokens, - ) - if args.random_ltd: - writer.add_scalar( - "seqlen/random_ltd_reserved_length", - args.random_ltd_reserved_length, - iteration, - ) - writer.add_scalar( - "seqlen/random_ltd_reserved_length vs samples", - args.random_ltd_reserved_length, - args.consumed_train_samples, - ) - writer.add_scalar( - "seqlen/random_ltd_reserved_length vs tokens", - args.random_ltd_reserved_length, - args.consumed_train_tokens, - ) - if args.log_memory_to_tensorboard: - mem_stats = torch.cuda.memory_stats() - writer.add_scalar( - "mem-reserved-bytes", - mem_stats["reserved_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-bytes", - mem_stats["allocated_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-count", - mem_stats["allocation.all.current"], - iteration, - ) - if iteration % args.tensorboard_log_interval == 0: - # This logging write various optimizer states to tensorboard. This - # feature may consume extra GPU memory thus is set at false by default. 
- if args.log_optimizer_states_to_tensorboard and optimizer is not None: - opt_stats = [0.0] * 8 - opt_stats_2 = [0.0] * 4 - for _, group in enumerate(optimizer.param_groups): - for _, param in enumerate(group["params"]): - state_param = getattr(optimizer, "state", None) - if state_param is not None: - exp_avg_sq = state_param.get("exp_avg_sq", torch.tensor(0.0)) - exp_avg = state_param.get("exp_avg", torch.tensor(0.0)) - opt_stats[0] += (torch.norm(exp_avg_sq).item()) ** 2 - opt_stats[1] += (torch.norm(exp_avg_sq.sqrt()).item()) ** 2 - opt_stats[2] += (torch.norm(exp_avg).item()) ** 2 - opt_stats[3] += (torch.norm(param).item()) ** 2 - opt_stats[4] += torch.norm(exp_avg_sq, p=1).item() - opt_stats[5] += torch.norm(exp_avg_sq.sqrt(), p=1).item() - opt_stats[6] += torch.norm(exp_avg, p=1).item() - opt_stats[7] += torch.norm(param, p=1).item() - opt_stats_2[0] = max( - opt_stats_2[0], - abs(exp_avg_sq.max().item()), - abs(exp_avg_sq.min().item()), - ) - opt_stats_2[1] = max( - opt_stats_2[1], exp_avg_sq.sqrt().abs_().max().item() - ) - opt_stats_2[2] = max( - opt_stats_2[2], - abs(exp_avg.max().item()), - abs(exp_avg.min().item()), - ) - opt_stats_2[3] = max( - opt_stats_2[3], - abs(param.max().item()), - abs(param.min().item()), - ) - # print('step {} rank {} before sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) - if args.zero_stage > 0: - # ZeRO partiions optimizer states - # opt_stats = opt_stats.clone().detach() - # opt_stats = get_accelerator().FloatTensor - opt_stats = get_accelerator().FloatTensor(opt_stats) - torch.distributed.all_reduce( - opt_stats, group=mpu.get_sequence_data_parallel_group() - ) - # opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) - # opt_stats_2 = opt_stats_2.clone().detach() - opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) - torch.distributed.all_reduce( - opt_stats_2, - op=torch.distributed.ReduceOp.MAX, - group=mpu.get_sequence_data_parallel_group(), - ) - - if args.tensor_model_parallel_size > 1: - # opt_stats = opt_stats.clone().detach() - opt_stats = get_accelerator().FloatTensor(opt_stats) - torch.distributed.all_reduce( - opt_stats, group=mpu.get_tensor_model_parallel_group() - ) - # opt_stats_2 = opt_stats_2.clone().detach() - opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) - torch.distributed.all_reduce( - opt_stats_2, - op=torch.distributed.ReduceOp.MAX, - group=mpu.get_tensor_model_parallel_group(), - ) - - if args.pipeline_model_parallel_size > 1: - # opt_stats = opt_stats.clone().detach() - opt_stats = get_accelerator().FloatTensor(opt_stats) - torch.distributed.all_reduce( - opt_stats, group=mpu.get_pipeline_model_parallel_group() - ) - # opt_stats_2 = opt_stats_2.clone().detach() - opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) - torch.distributed.all_reduce( - opt_stats_2, - op=torch.distributed.ReduceOp.MAX, - group=mpu.get_pipeline_model_parallel_group(), - ) - wandb_metrics |= { - "optimizer/learning_rate": learning_rate, - "optimizer/iteration": args.iteration, - "optimizer/consumed_train_tokens": args.consumed_train_tokens, - "optimizer/variance_l2": opt_stats[0] ** 0.5, - "optimizer/variance_sqrt_l2": opt_stats[1] ** 0.5, - "optimizer/momentum_l2": opt_stats[2] ** 0.5, - "optimizer/weight_l2": opt_stats[3] ** 0.5, - "optimizer/variance_l1": opt_stats[4], - "optimizer/variance_sqrt_l1": opt_stats[5], - "optimizer/momentum_l1": opt_stats[6], - "optimizer/weight_l1": opt_stats[7], - "optimizer/variance_abs_max": opt_stats_2[0], - 
"optimizer/variance_sqrt_abs_max": opt_stats_2[1], - "optimizer/momentum_abs_max": opt_stats_2[2], - "optimizer/weight_abs_max": opt_stats_2[3], - } - # print('step {} rank {} after sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) - if writer and is_last_rank(): - writer.add_scalar( - "optimizer/variance_l2 vs tokens", - opt_stats[0] ** 0.5, - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_sqrt_l2 vs tokens", - opt_stats[1] ** 0.5, - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/momentum_l2 vs tokens", - opt_stats[2] ** 0.5, - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/weight_l2 vs tokens", - opt_stats[3] ** 0.5, - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_l1 vs tokens", - opt_stats[4], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_sqrt_l1 vs tokens", - opt_stats[5], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/momentum_l1 vs tokens", - opt_stats[6], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/weight_l1 vs tokens", - opt_stats[7], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_abs_max vs tokens", - opt_stats_2[0], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_sqrt_abs_max vs tokens", - opt_stats_2[1], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/momentum_abs_max vs tokens", - opt_stats_2[2], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/weight_abs_max vs tokens", - opt_stats_2[3], - args.consumed_train_tokens, - ) - writer.add_scalar( - "optimizer/variance_l2", opt_stats[0] ** 0.5, iteration - ) - writer.add_scalar( - "optimizer/variance_sqrt_l2", opt_stats[1] ** 0.5, iteration - ) - writer.add_scalar( - "optimizer/momentum_l2", opt_stats[2] ** 0.5, iteration - ) - writer.add_scalar("optimizer/weight_l2", opt_stats[3] ** 0.5, iteration) - writer.add_scalar("optimizer/variance_l1", opt_stats[4], iteration) - writer.add_scalar("optimizer/variance_sqrt_l1", opt_stats[5], iteration) - writer.add_scalar("optimizer/momentum_l1", opt_stats[6], iteration) - writer.add_scalar("optimizer/weight_l1", opt_stats[7], iteration) - writer.add_scalar( - "optimizer/variance_abs_max", opt_stats_2[0], iteration - ) - writer.add_scalar( - "optimizer/variance_sqrt_abs_max", opt_stats_2[1], iteration - ) - writer.add_scalar( - "optimizer/momentum_abs_max", opt_stats_2[2], iteration - ) - writer.add_scalar("optimizer/weight_abs_max", opt_stats_2[3], iteration) + # Empty unused memory. 
+ if args.empty_unused_memory_level >= 2 and accelerator is not None: + accelerator.empty_cache() - assert args is not None - assert timers is not None - if iteration % args.log_interval == 0: - elapsed_time = timers("interval-time").elapsed(barrier=True) - elapsed_time_per_iteration = elapsed_time / total_iterations - seq_len = args.seq_length - if hasattr(args, "actual_seq_length"): - seq_len = args.actual_seq_length - samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( - model, args, elapsed_time, total_iterations - ) - samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size - tokens_per_sec = samples_per_sec * seq_len - tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size - tokens_per_gpu_per_second = tokens_per_sec / args.world_size - tokens_per_gpu_per_second_per_replica = ( - tokens_per_gpu_per_second / args.data_parallel_size - ) - # NOTE: [2024-06-19] - # Updated to use (more accurate) calculation according to - # `num_floating_point_operations` from NVIDIA/Megatron-LM - num_flop_lm = num_floating_point_operations(args, batch_size) - num_flop_per_sec_lm = (num_flop_lm / elapsed_time_per_iteration) - tflops_lm = (num_flop_per_sec_lm / (10 ** 12)) - tflops_lm_per_gpu = (tflops_lm / args.world_size) - wandb_metrics |= { - "throughput/iteration-time": elapsed_time_per_iteration, # 1000 ms / s - "throughput/samples_per_sec": samples_per_sec, - "throughput/samples_per_sec_per_replica": samples_per_sec_per_replica, - "throughput/tokens_per_sec": tokens_per_sec, - "throughput/tokens_per_sec_per_replica": tokens_per_sec_per_replica, - "throughput/tokens_per_gpu_per_sec": tokens_per_gpu_per_second, - "throughput/tokens_per_gpu_per_sec_per_replica": tokens_per_gpu_per_second_per_replica, - "throughput/tflops": tflops, - "throughput/tflops-new": num_flop_lm / elapsed_time_per_iteration, - "throughput/tflops-lm": tflops_lm_per_gpu, - "throughput/approx_params_in_billions": approx_parameters_in_billions, - "throughput/elapsed_ms_per_iteration": elapsed_time_per_iteration, - "throughput/iteration": iteration, - } - if loss_dict is not None: - wandb_metrics |= { - "loss/iteration": iteration, - **{f"loss/{k}": v for k, v in loss_dict.items()}, - } - if writer and args.log_timers_to_tensorboard: - writer.add_scalar( - "iteration-time/iteration-time", elapsed_time_per_iteration, iteration - ) - writer.add_scalar( - "iteration-time/iteration-time vs samples", - elapsed_time_per_iteration, - args.consumed_train_samples, - ) - writer.add_scalar( - "iteration-time/iteration-time vs tokens", - elapsed_time_per_iteration, - args.consumed_train_tokens, - ) - # metrics_to_log = { - # 'iteration': iteration, - # 'train_iters': args.train_iters, - # 'consumed_samples': args.consumed_train_samples, - # 'consumed_tokens': args.consumed_tokens, - # } - log_string = f" iteration={iteration:8d}/{args.train_iters:8d} |" - # .format( iteration, args.train_iters) - log_string += ( - f" consumed_samples={args.consumed_train_samples:12d} |" - # .format(args.consumed_train_samples) - ) - log_string += f" consumed_tokens={args.consumed_train_tokens:12d} |" - # .format( args.consumed_train_tokens) - log_string += ( - " elapsed_time_per_iteration_ms=" - f"{elapsed_time_per_iteration * 1000.0:.1f} |" - # .format( elapsed_time_per_iteration * 1000.0) - ) - log_string += f" learning_rate={learning_rate:.6g} |" - log_string += f" global_batch_size={batch_size:5d} |" - # if wandb is not None and getattr(wandb, 'run', None) is not None: - wandb_metrics |= { - 
"training/iteration": iteration, - "training/iteration_time": elapsed_time_per_iteration, - "training/iteration_time_vs_tokens": ( - elapsed_time_per_iteration / args.consumed_train_tokens - ), - "training/iteration_time_vs_samples": ( - (elapsed_time_per_iteration / args.consumed_train_samples), - ), - "training/consumed_samples": args.consumed_train_samples, - "training/consumed_tokens": args.consumed_train_tokens, - } - for key in total_loss_dict: - if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: - avg = total_loss_dict[key].item() / float( - max(1, total_loss_dict[advanced_iters_key]) - ) - if avg > 0.0: - log_string += " {}={:.6f} |".format(key, avg) - total_loss_dict[key] = get_accelerator().FloatTensor([0.0]) - if loss_scale is not None: - log_string += " loss_scale={:.1f} |".format(loss_scale) - wandb_metrics |= {"loss/loss_scale": loss_scale} - if grad_norm is not None: - log_string += " grad_norm={:.3f} |".format(grad_norm) - wandb_metrics |= {"loss/grad_norm": grad_norm} - if num_zeros_in_grad is not None: - log_string += " num_zeros={:.1f} |".format(num_zeros_in_grad) - wandb_metrics |= {"loss/num_zeros_in_grad": num_zeros_in_grad} - if params_norm is not None: - log_string += " params_norm={:.3f} |".format(params_norm) - wandb_metrics |= {"loss/params_norm": params_norm} - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - log_string += " curriculum_seqlen={:5d} |".format(args.curriculum_seqlen) - if args.random_ltd: - log_string += " random_ltd reserved_length={:5d} |".format( - args.random_ltd_reserved_length + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0]: + losses_reduced_for_key = [x[key] for x in losses_reduced] + loss_reduced[key] = sum(losses_reduced_for_key) / len( + losses_reduced_for_key ) - # log_string += " | ".join([ - # f"{seq_len=:5d} ", - # f"{}" - # f"number_of_skipped_iterations={:3d}", - # - # ]) - log_string += " actual_seqlen={:5d} |".format(seq_len) - log_string += " number_of_skipped_iterations={:3d} |".format( - total_loss_dict[skipped_iters_key] - ) - log_string += " number_of_nan_iterations={:3d} |".format( - total_loss_dict[nan_iters_key] - ) - log_string += " samples_per_second={:.3f} |".format(samples_per_sec) - log_string += " tokens_per_gpu_per_second_tgs={:.3f} |".format( - tokens_per_gpu_per_second - ) - log_string += " [LM]TFLOPs={:.2f} |".format(tflops_lm_per_gpu) - log_string += " [DS]TFLOPs={:.2f} |".format(tflops) - total_loss_dict[advanced_iters_key] = 0 - total_loss_dict[skipped_iters_key] = 0 - total_loss_dict[nan_iters_key] = 0 - # print_rank_last(log_string) - log.info(log_string) - if report_memory_flag and learning_rate > 0.0: - # Report memory after optimizer state has been initialized. 
- report_memory("(after {} iterations)".format(iteration)) - report_memory_flag = False - if wandb is not None and getattr(wandb, "run", None) is not None: - wandb_metrics |= { - "training/skiped_iterations": total_loss_dict[skipped_iters_key] - } - wandb_metrics |= {"training/nan_iterations": total_loss_dict[nan_iters_key]} - wandb.log(wandb_metrics) - if timers is not None: - timers.log(timers_to_log, normalizer=args.log_interval) - - return report_memory_flag + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad @dlp.log @@ -1626,23 +994,29 @@ def train( """Train the model function.""" args = get_args() timers = get_timers() - assert args is not None - assert timers is not None + accelerator = get_accelerator() + assert args is not None and timers is not None and accelerator is not None # Write args to tensorboard write_args_to_tensorboard() + assert accelerator is not None + setup_profiler(args, accelerator.device_name()) if args.random_ltd: # random-ltd requires different randomness on each rank import random + random.seed(args.seed + torch.distributed.get_rank()) # Turn on training mode which enables dropout. for model_module in model: model_module.train() + grad_norm = None # Tracking loss. total_loss_dict = {} + loss_dict = {"skipped_iter": 0} # Iterations. iteration = args.iteration # Translate args to core configuration config = core_transformer_config_from_args(args) + num_skipped_iters = 0 if not args.deepspeed: config.grad_scale_func = optimizer.scale_loss config.timers = timers @@ -1654,9 +1028,23 @@ def train( args.random_ltd_layer_num = model[ 0 ].random_ltd_scheduler.get_random_ltd_layer_num() + ranges_to_skip = None + if args.train_range_to_skip is not None: + assert ( + len(args.train_range_to_skip) % 2 == 0 + ), f"""Expected --train-range-to-skip to have an even number of values. 
+ Received: {len(args.train_range_to_skip)} + """ + ranges_to_skip = list( + zip( + args.train_range_to_skip[::2], + args.train_range_to_skip[1::2], + ) + ) while iteration < args.train_iters and ( args.train_tokens is None or args.consumed_train_tokens < args.train_tokens ): + trigger(on_step_begin) update_num_microbatches(args.consumed_train_samples) if args.deepspeed: # inform deepspeed of any batch size changes @@ -1675,18 +1063,61 @@ def train( update_rotary_pos_emb(curriculum_seqlen) args.curriculum_seqlen = curriculum_seqlen args.curr_iteration = iteration - if os.getenv("TORCH_PROFILER_ENABLE") == "2": - from torch.profiler import profile, record_function, ProfilerActivity - try: - activities = [ - ProfilerActivity.CPU, - ProfilerActivity.CUDA, - ProfilerActivity.XPU, - ] - except Exception: - log.warning("TORCH PROFILER WARNING: XPU is not supported") - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - with profile(activities=activities) as prof: + if ranges_to_skip is not None and any( + [i <= (iteration + 1) <= j for (i, j) in ranges_to_skip] + ): + log.info(f"Caught {iteration + 1} in 'ranges_to_skip', skipping!") + skipped_iter = 1 + num_skipped_iters += 1 + num_zeros_in_grad = None + gas = args.deepspeed_config_dict["gradient_accumulation_steps"] + for microstep in range(gas): + _batch = next(train_data_iterator) + _tokens = _batch["text"] + if ( + iteration < 10 + and os.environ.get("DUMP_SKIPPED_ITERS", None) + and RANK == 0 + ): + log.info(f"{_tokens.shape}, {len(train_data_iterator)=}") + log.info( + f"{iteration=} [{microstep}/{gas}]: ({_tokens.shape})\n{_tokens[:10]=}" + ) + + increment = ( + get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + ) + model[0].skipped_steps += 1 + model[0].global_steps += 1 + model[0].micro_steps += 1 + model[0].global_samples += model[0].train_batch_size() + opt_param_scheduler.step(increment=increment) + else: + if os.getenv("TORCH_PROFILER_ENABLE") == "2": + from torch.profiler import profile, ProfilerActivity + + try: + activities = [ + ProfilerActivity.CPU, + ProfilerActivity.CUDA, + ProfilerActivity.XPU, # type:ignore + ] + except Exception: + log.warning("TORCH PROFILER WARNING: XPU is not supported") + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + with profile(activities=activities) as prof: + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( + forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config, + ) + prof.export_chrome_trace( + f"{args.trace_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}-step{iteration}.json" + ) + else: loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( forward_step_func, train_data_iterator, @@ -1695,18 +1126,6 @@ def train( opt_param_scheduler, config, ) - prof.export_chrome_trace( - f"{args.trace_dir}/torch-trace-{RANK}-of-{WORLD_SIZE}-step{iteration}.json" - ) - else: - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( - forward_step_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config, - ) iteration += 1 args.iteration = iteration new_samples = ( @@ -1792,7 +1211,8 @@ def train( saved_checkpoint = False if args.exit_signal_handler: signal_handler = get_signal_handler() - if any(signal_handler.signals_received()): + # if any(signal_handler.signals_received()): + if signal_handler is not None and any(signal_handler.signals_received()): save_checkpoint_and_time( iteration, model, optimizer, opt_param_scheduler ) @@ -1804,9 +1224,7 @@ def train( 
# Exiting based on duration if args.exit_duration_in_mins: train_time = (time.time() - _TRAIN_START_TIME) / 60.0 - done_cuda = get_accelerator().IntTensor( - [train_time > args.exit_duration_in_mins] - ) + done_cuda = accelerator.IntTensor([train_time > args.exit_duration_in_mins]) torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) done = done_cuda.item() if done: @@ -1825,6 +1243,19 @@ def train( torch.distributed.barrier() print_datetime("exiting program at iteration {}".format(iteration)) sys.exit() + trigger(on_step_end) + # Exiting based on kill switch file + if found_kill_switch(): + if args.save and not saved_checkpoint: + save_checkpoint_and_time( + iteration, model, optimizer, opt_param_scheduler + ) + torch.distributed.barrier() + print_datetime( + f"Detected kill switch at {args.kill_switch_file}, " + f"iteration={iteration}. Exiting" + ) + sys.exit() return iteration @@ -1839,7 +1270,8 @@ def evaluate( ): """Evaluation.""" args = get_args() - assert args is not None + accelerator = get_accelerator() + assert args is not None and accelerator is not None if args.vision_pretraining and args.vision_pretraining_type == "dino": compute_feature_bank(model) @@ -1861,6 +1293,10 @@ def evaluate( total_loss_dict = {} + num_microbatches = get_num_microbatches() + assert num_microbatches is not None + forward_backward_func = get_forward_backward_func() + with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: @@ -1868,20 +1304,19 @@ def evaluate( if verbose and iteration % args.log_interval == 0: log.info("Evaluating iter {}/{}".format(iteration, args.eval_iters)) - forward_backward_func = get_forward_backward_func() # Don't care about timing during evaluation config.timers = None if args.deepspeed and args.ds_pipeline_enabled: # DeepSpeed uses eval_batch() and already aggregates losses. assert isinstance(model, list) and len(model) == 1 loss = model[0].eval_batch(data_iterator) - loss_dicts = [{"lm loss": loss}] * get_num_microbatches() + loss_dicts = [{"lm loss": loss}] * num_microbatches else: loss_dicts = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1891,7 +1326,7 @@ def evaluate( # Empty unused memory if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() + accelerator.empty_cache() if mpu.is_pipeline_last_stage(ignore_virtual=True): # Reduce across processes. 
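# NOTE: this patch consistently swaps device-specific `torch.cuda.*` calls
# (empty_cache, FloatTensor, LongTensor, IntTensor, ...) for DeepSpeed's
# accelerator abstraction, so the same code path runs on CUDA, XPU, and other
# backends. A minimal sketch of the pattern; `_accel` and `_flags` are
# illustrative names only:
from deepspeed.accelerator import get_accelerator

_accel = get_accelerator()
if _accel is not None:
    _accel.empty_cache()  # device-agnostic stand-in for torch.cuda.empty_cache()
    _flags = _accel.LongTensor([0, 0, 0])  # instead of torch.cuda.LongTensor([...])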
@@ -1899,16 +1334,14 @@ def evaluate( for key in loss_dict: if "moe" not in key: total_loss_dict[key] = ( - total_loss_dict.get( - key, get_accelerator().FloatTensor([0.0]) - ) + total_loss_dict.get(key, accelerator.FloatTensor([0.0])) + loss_dict[key] ) args.consumed_valid_samples += ( mpu.get_data_parallel_world_size() * args.micro_batch_size - * get_num_microbatches() + * num_microbatches ) collected_non_loss_data = None if process_non_loss_data_func is not None and is_last_rank(): @@ -1916,7 +1349,7 @@ def evaluate( forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, - num_microbatches=get_num_microbatches(), + num_microbatches=num_microbatches, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, @@ -1929,7 +1362,7 @@ def evaluate( model_module.train() for key in total_loss_dict: - total_loss_dict[key] /= args.eval_iters * get_num_microbatches() + total_loss_dict[key] /= args.eval_iters * num_microbatches if args.curriculum_learning_legacy and not args.no_pipeline_parallel: # roll back to actual curriculum seqlen at the end of eval. @@ -1959,6 +1392,7 @@ def evaluate_and_print_results( ): """Helper function to evaluate and dump results on screen.""" args = get_args() + assert args is not None if write_to_tensorboard: writer = get_tensorboard_writer() else: @@ -1978,7 +1412,7 @@ def evaluate_and_print_results( ppl = math.exp(min(20, total_loss_dict[key].item())) string += f"{key} PPL={ppl:.6f}" # string += '{} PPL={:.6f} | '.format(key, ppl) - if writer and is_last_rank(): + if writer is not None and is_last_rank(): data_type = "test" if test else "validation" writer.add_scalar( f"lm-loss-validation/{key} {data_type}", @@ -2033,6 +1467,7 @@ def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): args = get_args() # Number of train/valid/test samples. + assert args is not None if args.train_samples: train_samples = args.train_samples else: @@ -2058,7 +1493,8 @@ def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider): """Build pretraining data loaders.""" args = get_args() - assert args is not None + accelerator = get_accelerator() + assert args is not None and accelerator is not None (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) log.info("> building train, validation, and test datasets ...") # Backward compatibility, assume fixed batch size. @@ -2101,11 +1537,9 @@ def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider do_valid = valid_dataloader is not None and args.eval_iters > 0 do_test = test_dataloader is not None and args.eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. - flags = get_accelerator().LongTensor( - [int(do_train), int(do_valid), int(do_test)] - ) + flags = accelerator.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: - flags = get_accelerator().LongTensor([0, 0, 0]) + flags = accelerator.LongTensor([0, 0, 0]) # Broadcast num tokens. if ds_sequence_parallel: torch.distributed.broadcast( @@ -2131,6 +1565,7 @@ def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provid """Build pretraining data iterators.""" args = get_args() + assert args is not None # Build loaders. 
train_dataloader, valid_dataloader, test_dataloader = ( diff --git a/megatron/training_log.py b/megatron/training_log.py new file mode 100644 index 0000000000..3eb96c392d --- /dev/null +++ b/megatron/training_log.py @@ -0,0 +1,669 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +""" +training_log.py +""" + +import logging +import os + +from deepspeed import get_accelerator +import ezpz as ez +import torch + +from megatron.core import mpu +from megatron.global_vars import ( + get_args, + get_num_microbatches, + get_tensorboard_writer, + get_timers, +) +from megatron.utils import ( + Profile, + is_last_rank, + report_memory, + throughput_calculator, + num_floating_point_operations, +) + + +RANK: int = ez.get_rank() +WORLD_SIZE: int = ez.get_world_size() +DEVICE_TYPE: str = ez.dist.get_torch_device_type() +DEVICE: torch.device = torch.device(DEVICE_TYPE) + +log: logging.Logger = logging.getLogger(__name__) +LOG_LEVEL: str = str(os.environ.get("LOG_LEVEL", "INFO")).upper() +log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL") + +try: + import wandb +except (ImportError, ModuleNotFoundError): + wandb = None + + +dlp = Profile("TRAINING_LOG") + + +@dlp.log +def training_log( + loss_dict, + total_loss_dict, + learning_rate, + iteration, + loss_scale, + report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, + model=None, + optimizer=None, +): + """Log training information such as losses, timing, ....""" + args = get_args() + accelerator = get_accelerator() + timers = get_timers() + writer = get_tensorboard_writer() + assert args is not None and timers is not None and accelerator is not None + wandb_metrics = {} + # Advanced, skipped, and Nan iterations. + advanced_iters_key = "advanced iterations" + skipped_iters_key = "skipped iterations" + nan_iters_key = "nan iterations" + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = ( + total_loss_dict.get(advanced_iters_key, 0) + 1 + ) + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = ( + total_loss_dict.get(skipped_iters_key, 0) + skipped_iter + ) + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = ( + total_loss_dict.get(key, accelerator.FloatTensor([0.0])) + + loss_dict[key] + ) + else: + try: + value = loss_dict[key].float().sum().item() + except AttributeError: + value = loss_dict[key] + is_nan = value == float("inf") or value == -float("inf") or value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int( + got_nan + ) + + # Logging. 
+ timers_to_log = [ + "forward-backward", + "forward-compute", + "backward-compute", + "batch-generator", + "forward-recv", + "forward-send", + "backward-recv", + "backward-send", + "forward-send-forward-recv", + "forward-send-backward-recv", + "backward-send-forward-recv", + "backward-send-backward-recv", + "forward-backward-send-forward-backward-recv", + "layernorm-grads-all-reduce", + "embedding-grads-all-reduce", + "grads-all-reduce", + "grads-reduce-scatter", + "params-all-gather", + "optimizer-copy-to-main-grad", + "optimizer-unscale-and-check-inf", + "optimizer-clip-main-grad", + "optimizer-count-zeros", + "optimizer-inner-step", + "optimizer-copy-main-to-model-params", + "optimizer", + ] + + # Calculate batch size. + batch_size = ( + args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + ) + total_iterations = ( + total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] + ) + + # Tensorboard values. + # Timer requires all the ranks to call. + if args.log_timers_to_tensorboard and ( + iteration % args.tensorboard_log_interval == 0 + ): + timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) + if writer and (iteration % args.tensorboard_log_interval == 0): + writer.add_scalar( + "steps-vs-samples/y=steps,x=samples", iteration, args.consumed_train_samples + ) + writer.add_scalar( + "steps-vs-samples/y=samples,x=steps", args.consumed_train_samples, iteration + ) + writer.add_scalar( + "steps-vs-tokens/y=steps,x=tokens", iteration, args.consumed_train_tokens + ) + writer.add_scalar( + "steps-vs-tokens/y=tokens,x=steps", args.consumed_train_tokens, iteration + ) + if args.log_learning_rate_to_tensorboard: + wandb_metrics |= { + "learning-rate/iteration": iteration, + "learning-rate/learning-rate": learning_rate, + } + writer.add_scalar("learning-rate/learning-rate", learning_rate, iteration) + writer.add_scalar( + "learning-rate/learning-rate vs samples", + learning_rate, + args.consumed_train_samples, + ) + writer.add_scalar( + "learning-rate/learning-rate vs tokens", + learning_rate, + args.consumed_train_tokens, + ) + if args.log_batch_size_to_tensorboard: + writer.add_scalar("batch-size/batch-size", batch_size, iteration) + writer.add_scalar( + "batch-size/batch-size vs samples", + batch_size, + args.consumed_train_samples, + ) + writer.add_scalar( + "batch-size/batch-size vs tokens", + batch_size, + args.consumed_train_tokens, + ) + wandb_metrics |= { + "lm-loss-training/iteration": iteration, + "lm-loss-training/consumed_train_tokens": args.consumed_train_tokens, + } + for key in loss_dict: + wandb_metrics |= {f"lm-loss-training/{key}": loss_dict[key]} + writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration) + writer.add_scalar( + f"lm-loss-training/{key}" + " vs samples", + loss_dict[key], + args.consumed_train_samples, + ) + writer.add_scalar( + f"lm-loss-training/{key}" + " vs tokens", + loss_dict[key], + args.consumed_train_tokens, + ) + if args.fp16 and loss_scale and args.log_loss_scale_to_tensorboard: + writer.add_scalar("loss-scale/loss-scale", loss_scale, iteration) + writer.add_scalar( + "loss-scale/loss-scale vs samples", + loss_scale, + args.consumed_train_samples, + ) + writer.add_scalar( + "loss-scale/loss-scale vs tokens", + loss_scale, + args.consumed_train_tokens, + ) + if args.log_world_size_to_tensorboard: + writer.add_scalar("world-size/world-size", args.world_size, iteration) + writer.add_scalar( + "world-size/world-size vs samples", + args.world_size, + args.consumed_train_samples, + ) 
+ writer.add_scalar( + "world-size/world-size vs tokens", + args.world_size, + args.consumed_train_tokens, + ) + if grad_norm is not None: + wandb_metrics |= {"training/grad-norm": grad_norm} + writer.add_scalar("grad-norm/grad-norm", grad_norm, iteration) + writer.add_scalar( + "grad-norm/grad-norm vs samples", grad_norm, args.consumed_train_samples + ) + writer.add_scalar( + "grad-norm/grad-norm vs tokens", grad_norm, args.consumed_train_tokens + ) + if num_zeros_in_grad is not None: + wandb_metrics |= {"training/num-zeros": num_zeros_in_grad} + writer.add_scalar("num-zeros/num-zeros", num_zeros_in_grad, iteration) + writer.add_scalar( + "num-zeros/num-zeros vs samples", + num_zeros_in_grad, + args.consumed_train_samples, + ) + writer.add_scalar( + "num-zeros/num-zeros vs tokens", + num_zeros_in_grad, + args.consumed_train_tokens, + ) + if params_norm is not None: + wandb_metrics |= {"training/params-norm": params_norm} + writer.add_scalar("params-norm/params-norm", params_norm, iteration) + writer.add_scalar( + "params-norm/params-norm vs samples", + params_norm, + args.consumed_train_samples, + ) + writer.add_scalar( + "params-norm/params-norm vs tokens", + params_norm, + args.consumed_train_tokens, + ) + if hasattr(args, "actual_seq_length"): + writer.add_scalar( + "seqlen/actual_seq_length", args.actual_seq_length, iteration + ) + writer.add_scalar( + "seqlen/actual_seq_length vs samples", + args.actual_seq_length, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/actual_seq_length vs tokens", + args.actual_seq_length, + args.consumed_train_tokens, + ) + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + writer.add_scalar( + "seqlen/curriculum_seqlen", args.curriculum_seqlen, iteration + ) + writer.add_scalar( + "seqlen/curriculum_seqlen vs samples", + args.curriculum_seqlen, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/curriculum_seqlen vs tokens", + args.curriculum_seqlen, + args.consumed_train_tokens, + ) + if args.random_ltd: + writer.add_scalar( + "seqlen/random_ltd_reserved_length", + args.random_ltd_reserved_length, + iteration, + ) + writer.add_scalar( + "seqlen/random_ltd_reserved_length vs samples", + args.random_ltd_reserved_length, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/random_ltd_reserved_length vs tokens", + args.random_ltd_reserved_length, + args.consumed_train_tokens, + ) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + "mem-reserved-bytes", + mem_stats["reserved_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-bytes", + mem_stats["allocated_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-count", + mem_stats["allocation.all.current"], + iteration, + ) + if iteration % args.tensorboard_log_interval == 0: + # This logging write various optimizer states to tensorboard. This + # feature may consume extra GPU memory thus is set at false by default. 
+ if args.log_optimizer_states_to_tensorboard and optimizer is not None: + opt_stats = [0.0] * 8 + opt_stats_2 = [0.0] * 4 + for _, group in enumerate(optimizer.param_groups): + for _, param in enumerate(group["params"]): + state_param = getattr(optimizer, "state", None) + if state_param is not None: + exp_avg_sq = state_param.get("exp_avg_sq", torch.tensor(0.0)) + exp_avg = state_param.get("exp_avg", torch.tensor(0.0)) + opt_stats[0] += (torch.norm(exp_avg_sq).item()) ** 2 + opt_stats[1] += (torch.norm(exp_avg_sq.sqrt()).item()) ** 2 + opt_stats[2] += (torch.norm(exp_avg).item()) ** 2 + opt_stats[3] += (torch.norm(param).item()) ** 2 + opt_stats[4] += torch.norm(exp_avg_sq, p=1).item() + opt_stats[5] += torch.norm(exp_avg_sq.sqrt(), p=1).item() + opt_stats[6] += torch.norm(exp_avg, p=1).item() + opt_stats[7] += torch.norm(param, p=1).item() + opt_stats_2[0] = max( + opt_stats_2[0], + abs(exp_avg_sq.max().item()), + abs(exp_avg_sq.min().item()), + ) + opt_stats_2[1] = max( + opt_stats_2[1], exp_avg_sq.sqrt().abs_().max().item() + ) + opt_stats_2[2] = max( + opt_stats_2[2], + abs(exp_avg.max().item()), + abs(exp_avg.min().item()), + ) + opt_stats_2[3] = max( + opt_stats_2[3], + abs(param.max().item()), + abs(param.min().item()), + ) + # print('step {} rank {} before sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) + if args.zero_stage > 0: + # ZeRO partiions optimizer states + # opt_stats = opt_stats.clone().detach() + # opt_stats = get_accelerator().FloatTensor + opt_stats = accelerator.FloatTensor(opt_stats) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_sequence_data_parallel_group() + ) + # opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + # opt_stats_2 = opt_stats_2.clone().detach() + opt_stats_2 = accelerator.FloatTensor(opt_stats_2) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_sequence_data_parallel_group(), + ) + + if args.tensor_model_parallel_size > 1: + # opt_stats = opt_stats.clone().detach() + opt_stats = accelerator.FloatTensor(opt_stats) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_tensor_model_parallel_group() + ) + # opt_stats_2 = opt_stats_2.clone().detach() + opt_stats_2 = accelerator.FloatTensor(opt_stats_2) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_tensor_model_parallel_group(), + ) + + if args.pipeline_model_parallel_size > 1: + # opt_stats = opt_stats.clone().detach() + opt_stats = accelerator.FloatTensor(opt_stats) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_pipeline_model_parallel_group() + ) + # opt_stats_2 = opt_stats_2.clone().detach() + opt_stats_2 = accelerator.FloatTensor(opt_stats_2) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_pipeline_model_parallel_group(), + ) + wandb_metrics |= { + "optimizer/learning_rate": learning_rate, + "optimizer/iteration": args.iteration, + "optimizer/consumed_train_tokens": args.consumed_train_tokens, + "optimizer/variance_l2": opt_stats[0] ** 0.5, + "optimizer/variance_sqrt_l2": opt_stats[1] ** 0.5, + "optimizer/momentum_l2": opt_stats[2] ** 0.5, + "optimizer/weight_l2": opt_stats[3] ** 0.5, + "optimizer/variance_l1": opt_stats[4], + "optimizer/variance_sqrt_l1": opt_stats[5], + "optimizer/momentum_l1": opt_stats[6], + "optimizer/weight_l1": opt_stats[7], + "optimizer/variance_abs_max": opt_stats_2[0], + "optimizer/variance_sqrt_abs_max": opt_stats_2[1], + 
"optimizer/momentum_abs_max": opt_stats_2[2], + "optimizer/weight_abs_max": opt_stats_2[3], + } + # print('step {} rank {} after sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) + if writer and is_last_rank(): + writer.add_scalar( + "optimizer/variance_l2 vs tokens", + opt_stats[0] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_l2 vs tokens", + opt_stats[1] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_l2 vs tokens", + opt_stats[2] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_l2 vs tokens", + opt_stats[3] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_l1 vs tokens", + opt_stats[4], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_l1 vs tokens", + opt_stats[5], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_l1 vs tokens", + opt_stats[6], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_l1 vs tokens", + opt_stats[7], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_abs_max vs tokens", + opt_stats_2[0], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_abs_max vs tokens", + opt_stats_2[1], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_abs_max vs tokens", + opt_stats_2[2], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_abs_max vs tokens", + opt_stats_2[3], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_l2", opt_stats[0] ** 0.5, iteration + ) + writer.add_scalar( + "optimizer/variance_sqrt_l2", opt_stats[1] ** 0.5, iteration + ) + writer.add_scalar( + "optimizer/momentum_l2", opt_stats[2] ** 0.5, iteration + ) + writer.add_scalar("optimizer/weight_l2", opt_stats[3] ** 0.5, iteration) + writer.add_scalar("optimizer/variance_l1", opt_stats[4], iteration) + writer.add_scalar("optimizer/variance_sqrt_l1", opt_stats[5], iteration) + writer.add_scalar("optimizer/momentum_l1", opt_stats[6], iteration) + writer.add_scalar("optimizer/weight_l1", opt_stats[7], iteration) + writer.add_scalar( + "optimizer/variance_abs_max", opt_stats_2[0], iteration + ) + writer.add_scalar( + "optimizer/variance_sqrt_abs_max", opt_stats_2[1], iteration + ) + writer.add_scalar( + "optimizer/momentum_abs_max", opt_stats_2[2], iteration + ) + writer.add_scalar("optimizer/weight_abs_max", opt_stats_2[3], iteration) + + assert args is not None + assert timers is not None + if iteration % args.log_interval == 0: + elapsed_time = timers("interval-time").elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + seq_len = args.seq_length + if hasattr(args, "actual_seq_length"): + seq_len = args.actual_seq_length + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + model, args, elapsed_time, total_iterations + ) + samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size + tokens_per_sec = samples_per_sec * seq_len + tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size + tokens_per_gpu_per_second = tokens_per_sec / args.world_size + tokens_per_gpu_per_second_per_replica = ( + tokens_per_gpu_per_second / args.data_parallel_size + ) + # NOTE: [2024-06-19] + # Updated to use (more accurate) calculation according to + # `num_floating_point_operations` from NVIDIA/Megatron-LM + num_flop_lm = 
num_floating_point_operations(args, batch_size) + num_flop_per_sec_lm = num_flop_lm / elapsed_time_per_iteration + tflops_lm = num_flop_per_sec_lm / (10**12) + tflops_lm_per_gpu = tflops_lm / args.world_size + wandb_metrics |= { + "throughput/iteration-time": elapsed_time_per_iteration, # 1000 ms / s + "throughput/samples_per_sec": samples_per_sec, + "throughput/samples_per_sec_per_replica": samples_per_sec_per_replica, + "throughput/tokens_per_sec": tokens_per_sec, + "throughput/tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "throughput/tokens_per_gpu_per_sec": tokens_per_gpu_per_second, + "throughput/tokens_per_gpu_per_sec_per_replica": tokens_per_gpu_per_second_per_replica, + "throughput/tflops": tflops, + "throughput/tflops-new": num_flop_lm / elapsed_time_per_iteration, + "throughput/tflops-lm": tflops_lm_per_gpu, + "throughput/approx_params_in_billions": approx_parameters_in_billions, + "throughput/elapsed_ms_per_iteration": elapsed_time_per_iteration, + "throughput/iteration": iteration, + } + if loss_dict is not None: + wandb_metrics |= { + "loss/iteration": iteration, + **{f"loss/{k}": v for k, v in loss_dict.items()}, + } + if writer and args.log_timers_to_tensorboard: + writer.add_scalar( + "iteration-time/iteration-time", elapsed_time_per_iteration, iteration + ) + writer.add_scalar( + "iteration-time/iteration-time vs samples", + elapsed_time_per_iteration, + args.consumed_train_samples, + ) + writer.add_scalar( + "iteration-time/iteration-time vs tokens", + elapsed_time_per_iteration, + args.consumed_train_tokens, + ) + # metrics_to_log = { + # 'iteration': iteration, + # 'train_iters': args.train_iters, + # 'consumed_samples': args.consumed_train_samples, + # 'consumed_tokens': args.consumed_tokens, + # } + log_string = f" iteration={iteration:8d}/{args.train_iters:8d} |" + # .format( iteration, args.train_iters) + log_string += ( + f" consumed_samples={args.consumed_train_samples:12d} |" + # .format(args.consumed_train_samples) + ) + log_string += f" consumed_tokens={args.consumed_train_tokens:12d} |" + # .format( args.consumed_train_tokens) + log_string += ( + " elapsed_time_per_iteration_ms=" + f"{elapsed_time_per_iteration * 1000.0:.1f} |" + # .format( elapsed_time_per_iteration * 1000.0) + ) + log_string += f" learning_rate={learning_rate:.6g} |" + log_string += f" global_batch_size={batch_size:5d} |" + # if wandb is not None and getattr(wandb, 'run', None) is not None: + wandb_metrics |= { + "training/iteration": iteration, + "training/iteration_time": elapsed_time_per_iteration, + "training/iteration_time_vs_tokens": ( + elapsed_time_per_iteration / args.consumed_train_tokens + ), + "training/iteration_time_vs_samples": ( + (elapsed_time_per_iteration / args.consumed_train_samples), + ), + "training/consumed_samples": args.consumed_train_samples, + "training/consumed_tokens": args.consumed_train_tokens, + } + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: + avg = total_loss_dict[key].item() / float( + max(1, total_loss_dict[advanced_iters_key]) + ) + if avg > 0.0: + log_string += " {}={:.6f} |".format(key, avg) + total_loss_dict[key] = accelerator.FloatTensor([0.0]) + if loss_scale is not None: + log_string += " loss_scale={:.1f} |".format(loss_scale) + wandb_metrics |= {"loss/loss_scale": loss_scale} + if grad_norm is not None: + log_string += " grad_norm={:.3f} |".format(grad_norm) + wandb_metrics |= {"loss/grad_norm": grad_norm} + if num_zeros_in_grad is not None: + log_string += " num_zeros={:.1f} 
|".format(num_zeros_in_grad) + wandb_metrics |= {"loss/num_zeros_in_grad": num_zeros_in_grad} + if params_norm is not None: + log_string += " params_norm={:.3f} |".format(params_norm) + wandb_metrics |= {"loss/params_norm": params_norm} + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + log_string += " curriculum_seqlen={:5d} |".format(args.curriculum_seqlen) + if args.random_ltd: + log_string += " random_ltd reserved_length={:5d} |".format( + args.random_ltd_reserved_length + ) + # log_string += " | ".join([ + # f"{seq_len=:5d} ", + # f"{}" + # f"number_of_skipped_iterations={:3d}", + # + # ]) + log_string += " actual_seqlen={:5d} |".format(seq_len) + log_string += " number_of_skipped_iterations={:3d} |".format( + total_loss_dict[skipped_iters_key] + ) + log_string += " number_of_nan_iterations={:3d} |".format( + total_loss_dict[nan_iters_key] + ) + log_string += " samples_per_second={:.3f} |".format(samples_per_sec) + log_string += " tokens_per_gpu_per_second_tgs={:.3f} |".format( + tokens_per_gpu_per_second + ) + log_string += " [LM]TFLOPs={:.2f} |".format(tflops_lm_per_gpu) + log_string += " [DS]TFLOPs={:.2f} |".format(tflops) + if wandb is not None and getattr(wandb, "run", None) is not None: + wandb_metrics |= { + "training/skiped_iterations": total_loss_dict[skipped_iters_key] + } + wandb_metrics |= {"training/nan_iterations": total_loss_dict[nan_iters_key]} + wandb.log(wandb_metrics) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + # print_rank_last(log_string) + log.info(log_string) + if report_memory_flag and learning_rate > 0.0: + # Report memory after optimizer state has been initialized. + report_memory("(after {} iterations)".format(iteration)) + report_memory_flag = False + if timers is not None: + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag diff --git a/megatron/training_log_alcf.py b/megatron/training_log_alcf.py new file mode 100644 index 0000000000..dcd872971d --- /dev/null +++ b/megatron/training_log_alcf.py @@ -0,0 +1,725 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +"""Pretrain utilities.""" + +from enum import Enum + +# from deepspeed.accelerator import get_accelerator +# from deepspeed.compression.compress import redundancy_clean +import torch +import os +import logging + +from megatron import get_args +from megatron import get_timers +from megatron import get_tensorboard_writer +from megatron import get_wandb_writer +from megatron import get_num_microbatches +from megatron.core import mpu + +# from megatron import is_rank_0, print_rank_0 +# from megatron import print_rank_last +# from megatron.arguments import core_transformer_config_from_args +# from megatron.checkpointing import load_checkpoint +# from megatron.checkpointing import save_checkpoint +# from megatron.core import mpu, tensor_parallel +# from megatron.core.enums import ModelType +# from megatron.core.pipeline_parallel import get_forward_backward_func +# from megatron.data.data_samplers import build_pretraining_data_loader +# from megatron.initialize import initialize_megatron +# from megatron.initialize import write_args_to_tensorboard +# from megatron.initialize import set_jit_fusion_options +# from megatron.model import Float16Module +# from megatron.model import GPTModel +# from megatron.model import DistributedDataParallel as LocalDDP +# from megatron.model.transformer import ParallelTransformerLayer +# from megatron.model.vision.knn_monitor import compute_feature_bank +# from megatron.optimizer import get_megatron_optimizer +# from megatron.optimizer_param_scheduler import OptimizerParamScheduler +# from megatron.profiler import on_step_begin, on_step_end, setup_profiler, trigger +# from megatron.utils import check_adlr_autoresume_termination +# from megatron.utils import found_kill_switch, unwrap_model +import ezpz as ez + +# from megatron.utils import calc_params_l2_norm +from megatron.utils import ( + # checkpoint_throughput_calculator, + report_memory, + throughput_calculator, + # update_rotary_pos_emb, +) + +try: + import wandb +except (ImportError, ModuleNotFoundError): + wandb = None +# The earliest we can measure the start time. +# _TRAIN_START_TIME = time.time() + + +log = logging.getLogger(__name__) + + +class InteropLoggingTool(Enum): + TENSORBOARD = 1 + WANDB = 2 + + +RANK: int = ez.get_rank() +LOCAL_RANK: int = ez.get_local_rank() +WORLD_SIZE: int = ez.get_world_size() +DEVICE_TYPE: str = ez.dist.get_torch_device_type() +DEVICE_ID: str = f"{DEVICE_TYPE}:{LOCAL_RANK}" +DEVICE: torch.device = torch.device(DEVICE_TYPE) + +log: logging.Logger = logging.getLogger(__name__) +LOG_LEVEL: str = str(os.environ.get("LOG_LEVEL", "INFO")).upper() +log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL") + + +def num_floating_point_operations(args, batch_size): + # Group Query Attention. + # if not args.group_query_attention: + if not args.num_key_value_heads: + args.num_key_value_heads = args.num_attention_heads + # args.num_query_groups = args.num_attention_heads + # MoE. 
+ # num_experts_routed_to = 1 if args.num_experts is None else args.moe_router_topk + num_experts_routed_to = 1 if args.num_experts is None else args.topk + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + return ( + 12 + * batch_size + * args.seq_length + * args.num_layers + * args.hidden_size + * args.hidden_size + * ( + 1 + + ( + (args.ffn_hidden_size / args.hidden_size) + * num_experts_routed_to + * gated_linear_multiplier + ) + + (args.num_key_value_heads / args.num_attention_heads) + + (args.seq_length / args.hidden_size) + + (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size)) + ) + ) + + +def training_log( + loss_dict, + total_loss_dict, + learning_rate, + iteration, + loss_scale, + report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, + model=None, + optimizer=None, +): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + assert args is not None and timers is not None + wandb_metrics = {} + # Advanced, skipped, and Nan iterations. + advanced_iters_key = "advanced iterations" + skipped_iters_key = "skipped iterations" + nan_iters_key = "nan iterations" + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = ( + total_loss_dict.get(advanced_iters_key, 0) + 1 + ) + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = ( + total_loss_dict.get(skipped_iters_key, 0) + skipped_iter + ) + # Update losses and set nan iterations + got_nan = False + _zero = torch.tensor([0.0]).to(DEVICE) + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = total_loss_dict.get(key, _zero) + loss_dict[key] + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float("inf") or value == -float("inf") or value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int( + got_nan + ) + + # Logging. + timers_to_log = [ + "forward-backward", + "forward-compute", + "backward-compute", + "batch-generator", + "forward-recv", + "forward-send", + "backward-recv", + "backward-send", + "forward-send-forward-recv", + "forward-send-backward-recv", + "backward-send-forward-recv", + "backward-send-backward-recv", + "forward-backward-send-forward-backward-recv", + "layernorm-grads-all-reduce", + "embedding-grads-all-reduce", + "grads-all-reduce", + "grads-reduce-scatter", + "params-all-gather", + "optimizer-copy-to-main-grad", + "optimizer-unscale-and-check-inf", + "optimizer-clip-main-grad", + "optimizer-count-zeros", + "optimizer-inner-step", + "optimizer-copy-main-to-model-params", + "optimizer", + ] + + # Calculate batch size. + batch_size = ( + args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + ) + total_iterations = ( + total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] + ) + + # Tensorboard values. + # Timer requires all the ranks to call. 
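# Illustrative sketch (not part of the patch): the num_floating_point_operations helper
# above implements the NVIDIA/Megatron-LM per-iteration FLOP estimate. A worked example
# with made-up, roughly Llama-2-7B-shaped values (dense model, so the MoE routing factor
# is 1); every number below is an assumption, not taken from this patch:
batch_size = 256            # global batch size
seq_length = 4096
num_layers = 32
hidden_size = 4096
ffn_hidden_size = 11008
num_attention_heads = 32
num_key_value_heads = 32    # no grouped-query attention
padded_vocab_size = 32000
gated_linear_multiplier = 3 / 2   # SwiGLU

flops_per_iteration = (
    12 * batch_size * seq_length * num_layers * hidden_size * hidden_size
    * (
        1
        + (ffn_hidden_size / hidden_size) * gated_linear_multiplier
        + num_key_value_heads / num_attention_heads
        + seq_length / hidden_size
        + padded_vocab_size / (2 * num_layers * hidden_size)
    )
)
print(f"{flops_per_iteration / 1e15:.1f} PFLOPs per iteration")
# Dividing by the measured seconds per iteration and by world_size gives the
# per-GPU TFLOPS that training_log() reports as `tflops_lm_per_gpu`.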
+ if args.log_timers_to_tensorboard and ( + iteration % args.tensorboard_log_interval == 0 and writer is not None + ): + timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) + if writer and (iteration % args.tensorboard_log_interval == 0): + writer.add_scalar( + "steps-vs-samples/y=steps,x=samples", iteration, args.consumed_train_samples + ) + writer.add_scalar( + "steps-vs-samples/y=samples,x=steps", args.consumed_train_samples, iteration + ) + writer.add_scalar( + "steps-vs-tokens/y=steps,x=tokens", iteration, args.consumed_train_tokens + ) + writer.add_scalar( + "steps-vs-tokens/y=tokens,x=steps", args.consumed_train_tokens, iteration + ) + if args.log_learning_rate_to_tensorboard: + wandb_metrics |= { + "learning-rate/iteration": iteration, + "learning-rate/learning-rate": learning_rate, + } + writer.add_scalar("learning-rate/learning-rate", learning_rate, iteration) + writer.add_scalar( + "learning-rate/learning-rate vs samples", + learning_rate, + args.consumed_train_samples, + ) + writer.add_scalar( + "learning-rate/learning-rate vs tokens", + learning_rate, + args.consumed_train_tokens, + ) + if args.log_batch_size_to_tensorboard: + writer.add_scalar("batch-size/batch-size", batch_size, iteration) + writer.add_scalar( + "batch-size/batch-size vs samples", + batch_size, + args.consumed_train_samples, + ) + writer.add_scalar( + "batch-size/batch-size vs tokens", + batch_size, + args.consumed_train_tokens, + ) + wandb_metrics |= { + "lm-loss-training/iteration": iteration, + "lm-loss-training/consumed_train_tokens": args.consumed_train_tokens, + } + for key in loss_dict: + wandb_metrics |= {f"lm-loss-training/{key}": loss_dict[key]} + writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration) + writer.add_scalar( + f"lm-loss-training/{key}" + " vs samples", + loss_dict[key], + args.consumed_train_samples, + ) + writer.add_scalar( + f"lm-loss-training/{key}" + " vs tokens", + loss_dict[key], + args.consumed_train_tokens, + ) + if args.fp16 and loss_scale and args.log_loss_scale_to_tensorboard: + writer.add_scalar("loss-scale/loss-scale", loss_scale, iteration) + writer.add_scalar( + "loss-scale/loss-scale vs samples", + loss_scale, + args.consumed_train_samples, + ) + writer.add_scalar( + "loss-scale/loss-scale vs tokens", + loss_scale, + args.consumed_train_tokens, + ) + if args.log_world_size_to_tensorboard: + writer.add_scalar("world-size/world-size", args.world_size, iteration) + writer.add_scalar( + "world-size/world-size vs samples", + args.world_size, + args.consumed_train_samples, + ) + writer.add_scalar( + "world-size/world-size vs tokens", + args.world_size, + args.consumed_train_tokens, + ) + if grad_norm is not None: + wandb_metrics |= {"training/grad-norm": grad_norm} + writer.add_scalar("grad-norm/grad-norm", grad_norm, iteration) + writer.add_scalar( + "grad-norm/grad-norm vs samples", grad_norm, args.consumed_train_samples + ) + writer.add_scalar( + "grad-norm/grad-norm vs tokens", grad_norm, args.consumed_train_tokens + ) + if num_zeros_in_grad is not None: + wandb_metrics |= {"training/num-zeros": num_zeros_in_grad} + writer.add_scalar("num-zeros/num-zeros", num_zeros_in_grad, iteration) + writer.add_scalar( + "num-zeros/num-zeros vs samples", + num_zeros_in_grad, + args.consumed_train_samples, + ) + writer.add_scalar( + "num-zeros/num-zeros vs tokens", + num_zeros_in_grad, + args.consumed_train_tokens, + ) + if params_norm is not None: + wandb_metrics |= {"training/params-norm": params_norm} + 
writer.add_scalar("params-norm/params-norm", params_norm, iteration) + writer.add_scalar( + "params-norm/params-norm vs samples", + params_norm, + args.consumed_train_samples, + ) + writer.add_scalar( + "params-norm/params-norm vs tokens", + params_norm, + args.consumed_train_tokens, + ) + if hasattr(args, "actual_seq_length"): + writer.add_scalar( + "seqlen/actual_seq_length", args.actual_seq_length, iteration + ) + writer.add_scalar( + "seqlen/actual_seq_length vs samples", + args.actual_seq_length, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/actual_seq_length vs tokens", + args.actual_seq_length, + args.consumed_train_tokens, + ) + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + writer.add_scalar( + "seqlen/curriculum_seqlen", args.curriculum_seqlen, iteration + ) + writer.add_scalar( + "seqlen/curriculum_seqlen vs samples", + args.curriculum_seqlen, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/curriculum_seqlen vs tokens", + args.curriculum_seqlen, + args.consumed_train_tokens, + ) + if args.random_ltd: + writer.add_scalar( + "seqlen/random_ltd_reserved_length", + args.random_ltd_reserved_length, + iteration, + ) + writer.add_scalar( + "seqlen/random_ltd_reserved_length vs samples", + args.random_ltd_reserved_length, + args.consumed_train_samples, + ) + writer.add_scalar( + "seqlen/random_ltd_reserved_length vs tokens", + args.random_ltd_reserved_length, + args.consumed_train_tokens, + ) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + "mem-reserved-bytes", + mem_stats["reserved_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-bytes", + mem_stats["allocated_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-count", + mem_stats["allocation.all.current"], + iteration, + ) + if iteration % args.tensorboard_log_interval == 0: + # This logging write various optimizer states to tensorboard. This + # feature may consume extra GPU memory thus is set at false by default. 
+ if args.log_optimizer_states_to_tensorboard and optimizer is not None: + opt_stats = [0.0] * 8 + opt_stats_2 = [0.0] * 4 + for _, group in enumerate(optimizer.param_groups): + for _, param in enumerate(group["params"]): + state_param = getattr(optimizer, "state", None) + if state_param is not None: + exp_avg_sq = state_param.get("exp_avg_sq", torch.tensor(0.0)) + exp_avg = state_param.get("exp_avg", torch.tensor(0.0)) + opt_stats[0] += (torch.norm(exp_avg_sq).item()) ** 2 + opt_stats[1] += (torch.norm(exp_avg_sq.sqrt()).item()) ** 2 + opt_stats[2] += (torch.norm(exp_avg).item()) ** 2 + opt_stats[3] += (torch.norm(param).item()) ** 2 + opt_stats[4] += torch.norm(exp_avg_sq, p=1).item() + opt_stats[5] += torch.norm(exp_avg_sq.sqrt(), p=1).item() + opt_stats[6] += torch.norm(exp_avg, p=1).item() + opt_stats[7] += torch.norm(param, p=1).item() + opt_stats_2[0] = max( + opt_stats_2[0], + abs(exp_avg_sq.max().item()), + abs(exp_avg_sq.min().item()), + ) + opt_stats_2[1] = max( + opt_stats_2[1], exp_avg_sq.sqrt().abs_().max().item() + ) + opt_stats_2[2] = max( + opt_stats_2[2], + abs(exp_avg.max().item()), + abs(exp_avg.min().item()), + ) + opt_stats_2[3] = max( + opt_stats_2[3], + abs(param.max().item()), + abs(param.min().item()), + ) + if args.zero_stage > 0: + # ZeRO partiions optimizer states + # opt_stats = get_accelerator().FloatTensor(opt_stats) + opt_stats = torch.tensor(opt_stats).to(DEVICE) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_sequence_data_parallel_group() + ) + # opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + opt_stats_2 = torch.tensor(opt_stats_2).to(DEVICE) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_sequence_data_parallel_group(), + ) + + if args.tensor_model_parallel_size > 1: + opt_stats = torch.tensor(opt_stats).to(DEVICE) + # opt_stats = get_accelerator().FloatTensor(opt_stats) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_tensor_model_parallel_group() + ) + # opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + opt_stats_2 = torch.tensor(opt_stats_2).to(DEVICE) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_tensor_model_parallel_group(), + ) + + if args.pipeline_model_parallel_size > 1: + # opt_stats = get_accelerator().FloatTensor(opt_stats) + opt_stats = torch.tensor(opt_stats).to(DEVICE) + torch.distributed.all_reduce( + opt_stats, group=mpu.get_pipeline_model_parallel_group() + ) + # opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2) + opt_stats_2 = torch.tensor(opt_stats_2).to(DEVICE) + torch.distributed.all_reduce( + opt_stats_2, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_pipeline_model_parallel_group(), + ) + + wandb_metrics |= { + "optimizer/learning_rate": learning_rate, + "optimizer/iteration": args.iteration, + "optimizer/consumed_train_tokens": args.consumed_train_tokens, + "optimizer/variance_l2": opt_stats[0] ** 0.5, + "optimizer/variance_sqrt_l2": opt_stats[1] ** 0.5, + "optimizer/momentum_l2": opt_stats[2] ** 0.5, + "optimizer/weight_l2": opt_stats[3] ** 0.5, + "optimizer/variance_l1": opt_stats[4], + "optimizer/variance_sqrt_l1": opt_stats[5], + "optimizer/momentum_l1": opt_stats[6], + "optimizer/weight_l1": opt_stats[7], + "optimizer/variance_abs_max": opt_stats_2[0], + "optimizer/variance_sqrt_abs_max": opt_stats_2[1], + "optimizer/momentum_abs_max": opt_stats_2[2], + "optimizer/weight_abs_max": opt_stats_2[3], + } + # print('step {} rank {} after sync opt_stats {}, 
{}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats)) + # if writer and is_last_rank(): + if writer is not None and RANK == 0: + writer.add_scalar( + "optimizer/variance_l2 vs tokens", + opt_stats[0] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_l2 vs tokens", + opt_stats[1] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_l2 vs tokens", + opt_stats[2] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_l2 vs tokens", + opt_stats[3] ** 0.5, + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_l1 vs tokens", + opt_stats[4], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_l1 vs tokens", + opt_stats[5], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_l1 vs tokens", + opt_stats[6], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_l1 vs tokens", + opt_stats[7], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_abs_max vs tokens", + opt_stats_2[0], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_sqrt_abs_max vs tokens", + opt_stats_2[1], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/momentum_abs_max vs tokens", + opt_stats_2[2], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/weight_abs_max vs tokens", + opt_stats_2[3], + args.consumed_train_tokens, + ) + writer.add_scalar( + "optimizer/variance_l2", opt_stats[0] ** 0.5, iteration + ) + writer.add_scalar( + "optimizer/variance_sqrt_l2", opt_stats[1] ** 0.5, iteration + ) + writer.add_scalar( + "optimizer/momentum_l2", opt_stats[2] ** 0.5, iteration + ) + writer.add_scalar("optimizer/weight_l2", opt_stats[3] ** 0.5, iteration) + writer.add_scalar("optimizer/variance_l1", opt_stats[4], iteration) + writer.add_scalar("optimizer/variance_sqrt_l1", opt_stats[5], iteration) + writer.add_scalar("optimizer/momentum_l1", opt_stats[6], iteration) + writer.add_scalar("optimizer/weight_l1", opt_stats[7], iteration) + writer.add_scalar( + "optimizer/variance_abs_max", opt_stats_2[0], iteration + ) + writer.add_scalar( + "optimizer/variance_sqrt_abs_max", opt_stats_2[1], iteration + ) + writer.add_scalar( + "optimizer/momentum_abs_max", opt_stats_2[2], iteration + ) + writer.add_scalar("optimizer/weight_abs_max", opt_stats_2[3], iteration) + + assert args is not None + assert timers is not None + if iteration % args.log_interval == 0: + elapsed_time = timers("interval-time").elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + seq_len = args.seq_length + if hasattr(args, "actual_seq_length"): + seq_len = args.actual_seq_length + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + model, args, elapsed_time, total_iterations + ) + samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size + tokens_per_sec = samples_per_sec * seq_len + tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size + tokens_per_gpu_per_second = tokens_per_sec / args.world_size + tokens_per_gpu_per_second_per_replica = ( + tokens_per_gpu_per_second / args.data_parallel_size + ) + # NOTE: [2024-06-19] + # Updated to use (more accurate) calculation according to + # `num_floating_point_operations` from NVIDIA/Megatron-LM + num_flop_lm = num_floating_point_operations(args, batch_size) + num_flop_per_sec_lm = num_flop_lm / elapsed_time_per_iteration + tflops_lm = 
num_flop_per_sec_lm / (10**12) + tflops_lm_per_gpu = tflops_lm / args.world_size + wandb_metrics |= { + "throughput/iteration-time": elapsed_time_per_iteration, # 1000 ms / s + "throughput/samples_per_sec": samples_per_sec, + "throughput/samples_per_sec_per_replica": samples_per_sec_per_replica, + "throughput/tokens_per_sec": tokens_per_sec, + "throughput/tokens_per_sec_per_replica": tokens_per_sec_per_replica, + "throughput/tokens_per_gpu_per_sec": tokens_per_gpu_per_second, + "throughput/tokens_per_gpu_per_sec_per_replica": tokens_per_gpu_per_second_per_replica, + "throughput/tflops": tflops, + "throughput/tflops-new": num_flop_lm / elapsed_time_per_iteration, + "throughput/tflops-lm": tflops_lm_per_gpu, + "throughput/approx_params_in_billions": approx_parameters_in_billions, + "throughput/elapsed_ms_per_iteration": elapsed_time_per_iteration, + "throughput/iteration": iteration, + } + if loss_dict is not None: + wandb_metrics |= { + "loss/iteration": iteration, + **{f"loss/{k}": v for k, v in loss_dict.items()}, + } + if writer and args.log_timers_to_tensorboard: + writer.add_scalar( + "iteration-time/iteration-time", elapsed_time_per_iteration, iteration + ) + writer.add_scalar( + "iteration-time/iteration-time vs samples", + elapsed_time_per_iteration, + args.consumed_train_samples, + ) + writer.add_scalar( + "iteration-time/iteration-time vs tokens", + elapsed_time_per_iteration, + args.consumed_train_tokens, + ) + # metrics_to_log = { + # 'iteration': iteration, + # 'train_iters': args.train_iters, + # 'consumed_samples': args.consumed_train_samples, + # 'consumed_tokens': args.consumed_tokens, + # } + log_string = f" iteration={iteration:8d}/{args.train_iters:8d} |" + # .format( iteration, args.train_iters) + log_string += ( + f" consumed_samples={args.consumed_train_samples:12d} |" + # .format(args.consumed_train_samples) + ) + log_string += f" consumed_tokens={args.consumed_train_tokens:12d} |" + # .format( args.consumed_train_tokens) + log_string += ( + " elapsed_time_per_iteration_ms=" + f"{elapsed_time_per_iteration * 1000.0:.1f} |" + # .format( elapsed_time_per_iteration * 1000.0) + ) + log_string += f" learning_rate={learning_rate:.6g} |" + log_string += f" global_batch_size={batch_size:5d} |" + # if wandb is not None and getattr(wandb, 'run', None) is not None: + wandb_metrics |= { + "training/iteration": iteration, + "training/iteration_time": elapsed_time_per_iteration, + "training/iteration_time_vs_tokens": ( + elapsed_time_per_iteration / args.consumed_train_tokens + ), + "training/iteration_time_vs_samples": ( + (elapsed_time_per_iteration / args.consumed_train_samples), + ), + "training/consumed_samples": args.consumed_train_samples, + "training/consumed_tokens": args.consumed_train_tokens, + } + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: + avg = total_loss_dict[key].item() / float( + max(1, total_loss_dict[advanced_iters_key]) + ) + if avg > 0.0: + log_string += " {}={:.6f} |".format(key, avg) + total_loss_dict[key] = torch.tensor([0.0]).to(DEVICE) + if loss_scale is not None: + log_string += " loss_scale={:.1f} |".format(loss_scale) + wandb_metrics |= {"loss/loss_scale": loss_scale} + if grad_norm is not None: + log_string += " grad_norm={:.3f} |".format(grad_norm) + wandb_metrics |= {"loss/grad_norm": grad_norm} + if num_zeros_in_grad is not None: + log_string += " num_zeros={:.1f} |".format(num_zeros_in_grad) + wandb_metrics |= {"loss/num_zeros_in_grad": num_zeros_in_grad} + if params_norm is not None: + 
log_string += " params_norm={:.3f} |".format(params_norm) + wandb_metrics |= {"loss/params_norm": params_norm} + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + log_string += " curriculum_seqlen={:5d} |".format(args.curriculum_seqlen) + if args.random_ltd: + log_string += " random_ltd reserved_length={:5d} |".format( + args.random_ltd_reserved_length + ) + # log_string += " | ".join([ + # f"{seq_len=:5d} ", + # f"{}" + # f"number_of_skipped_iterations={:3d}", + # + # ]) + log_string += " actual_seqlen={:5d} |".format(seq_len) + log_string += " number_of_skipped_iterations={:3d} |".format( + total_loss_dict[skipped_iters_key] + ) + log_string += " number_of_nan_iterations={:3d} |".format( + total_loss_dict[nan_iters_key] + ) + log_string += " samples_per_second={:.3f} |".format(samples_per_sec) + log_string += " tokens_per_gpu_per_second_tgs={:.3f} |".format( + tokens_per_gpu_per_second + ) + log_string += " [LM]TFLOPs={:.2f} |".format(tflops_lm_per_gpu) + log_string += " [DS]TFLOPs={:.2f} |".format(tflops) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + # print_rank_last(log_string) + log.info(log_string) + if report_memory_flag and learning_rate > 0.0: + # Report memory after optimizer state has been initialized. + report_memory("(after {} iterations)".format(iteration)) + report_memory_flag = False + if wandb is not None and getattr(wandb, "run", None) is not None: + wandb_metrics |= { + "training/skiped_iterations": total_loss_dict[skipped_iters_key] + } + wandb_metrics |= {"training/nan_iterations": total_loss_dict[nan_iters_key]} + wandb.log(wandb_metrics) + if timers is not None: + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag diff --git a/megatron/utils.py b/megatron/utils.py index 8a9f3e7858..3d5eef4672 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -4,24 +4,14 @@ import sys import os -import time import logging -from typing import ContextManager, Optional +from typing import Optional import torch from torch.nn.parallel import DistributedDataParallel as torchDDP from deepspeed.accelerator import get_accelerator -if get_accelerator().device_name() == "cuda": - try: - from apex.multi_tensor_apply import multi_tensor_applier - import amp_C - - HAS_APEX = True - except Exception: - HAS_APEX = False - from megatron import get_args, get_adlr_autoresume, get_num_microbatches from megatron.core import mpu from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate @@ -30,80 +20,115 @@ import ezpz as ez +ACCELERATOR = get_accelerator() +assert ACCELERATOR is not None + +if ACCELERATOR.device_name() == "cuda": + try: + from apex.multi_tensor_apply import multi_tensor_applier # type:ignore + import amp_C # type:ignore + + HAS_APEX = True + except Exception: + HAS_APEX = False + RANK = ez.get_rank() log = logging.getLogger(__name__) -# log.setLevel("INFO") if RANK == 0 else log.setLevel("CRITICAL") - +log.setLevel(os.environ.get("LOG_LEVEL", ("INFO" if RANK == 0 else "CRITICAL"))) _DLIO_PROFILER_EXIST = True -_DFTRACER_EXIST=True +_DFTRACER_EXIST = True try: - import dftracer -except: - _DFTRACER_EXIST=False - + import dftracer # type:ignore +except Exception: + _DFTRACER_EXIST = False + try: - import dlio_profiler -except: + import dlio_profiler # type:ignore +except Exception: _DLIO_PROFILER_EXIST = False - + if _DFTRACER_EXIST: - from dftracer.logger import dftracer as PerfTrace, dft_fn as Profile, DFTRACER_ENABLE 
as DFTRACER_ENABLE + from dftracer.logger import ( # type:ignore + dftracer as PerfTrace, + dft_fn as Profile, + DFTRACER_ENABLE as DFTRACER_ENABLE, + ) elif _DLIO_PROFILER_EXIST: - from dlio_profiler.logger import fn_interceptor as Profile - from dlio_profiler.logger import dlio_logger as PerfTrace + from dlio_profiler.logger import fn_interceptor as Profile # type:ignore + from dlio_profiler.logger import dlio_logger as PerfTrace # type:ignore else: from functools import wraps - # from contextlib import nullcontext - # Profile: ContextManager = nullcontext - # - # class Profile(nullable_schema) + class Profile(object): - def __init__(self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None): - return - def log(self, func): + def __init__( + self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None + ): + return + + def log(self, func): return func - def log_init(self, func): + + def log_init(self, func): return func - def iter(self, func, iter_name="step"): + + def iter(self, func, iter_name="step"): return func + def __enter__(self): return + def __exit__(self, type, value, traceback): return - def update(self, epoch=None, step=None, image_idx=None, image_size=None, args={}): + + def update( + self, epoch=None, step=None, image_idx=None, image_size=None, args={} + ): return + def flush(self): return + def reset(self): return + def log_static(self, func): return + class dftracer(object): - def __init__(self,): + def __init__( + self, + ): self.type = None + def initialize_log(self, logfile=None, data_dir=None, process_id=-1): return + def get_time(self): return + def enter_event(self): return + def exit_event(self): return + def log_event(self, name, cat, start_time, duration, string_args=None): return + def finalize(self): return PerfTrace = dftracer() DFTRACER_ENABLE = False + def get_logger( - name: str, - level: str = "INFO", - rank_zero_only: Optional[bool] = None, + name: str, + level: Optional[str] = None, + rank_zero_only: Optional[bool] = True, ) -> logging.Logger: """Returns a `logging.Logger` object. @@ -111,7 +136,9 @@ def get_logger( non-zero ranks (and will be set to `level` on RANK==0). 
""" logger = logging.getLogger(name) - logger.setLevel(level) + logger.setLevel( + str(level if level is not None else os.environ.get("LOG_LEVEL", "INFO")).upper() + ) if rank_zero_only and ez.get_rank() != 0: logger.setLevel("CRITICAL") return logger @@ -119,7 +146,8 @@ def get_logger( def update_rotary_pos_emb(seq_length): args = get_args() - assert args is not None + accelerator = get_accelerator() + assert args is not None and accelerator is not None rotary_dim = ( args.hidden_size // args.num_attention_heads if args.kv_channels is None @@ -133,7 +161,7 @@ def update_rotary_pos_emb(seq_length): # Wang and Komatsuzaki et al # https://github.com/kingoflolz/mesh-transformer-jax/ rotary_pos_emb = RotaryEmbedding(rotary_dim, theta=args.rope_theta)(seq_length).to( - get_accelerator().current_device_name() + accelerator.current_device_name() ) args.rotary_pos_emb = rotary_pos_emb @@ -203,21 +231,22 @@ def average_losses_across_data_parallel_group(losses): def report_memory(name): """Simple GPU memory report.""" + accelerator = get_accelerator() + assert accelerator is not None mega_bytes = 1024.0 * 1024.0 string = name + " memory (MB)" - string += " | allocated: {}".format( - get_accelerator().memory_allocated() / mega_bytes - ) + string += " | allocated: {}".format(accelerator.memory_allocated() / mega_bytes) string += " | max allocated: {}".format( - get_accelerator().max_memory_allocated() / mega_bytes - ) - string += " | reserved: {}".format(get_accelerator().memory_reserved() / mega_bytes) - string += " | max reserved: {}".format( - get_accelerator().max_memory_reserved() / mega_bytes + accelerator.max_memory_allocated() / mega_bytes ) + reserved = accelerator.memory_reserved() + max_reserved = accelerator.max_memory_reserved() + if reserved is not None: + string += " | reserved: {}".format(reserved / mega_bytes) + if max_reserved is not None: + string += " | max reserved: {}".format(max_reserved / mega_bytes) if mpu.get_data_parallel_rank() == 0: log.info(f"[Rank {RANK}] {string}") - # log.info("[Rank {}] {}".format(torch.distributed.get_rank(), string)) # , flush=True) def print_params_min_max_norm(optimizer, iteration): @@ -236,19 +265,19 @@ def print_params_min_max_norm(optimizer, iteration): iteration, rank, index, int(param.tensor_model_parallel) ) string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm) - # print(string, flush=True) log.info(string) def check_adlr_autoresume_termination(iteration, model, optimizer, opt_param_scheduler): """Check for autoresume signal and exit if it is received.""" from megatron.checkpointing import save_checkpoint + args = get_args() assert args is not None autoresume = get_adlr_autoresume() # Add barrier to ensure consistnecy. torch.distributed.barrier() - if autoresume.termination_requested(): + if autoresume is not None and autoresume.termination_requested(): if args.save: save_checkpoint(iteration, model, optimizer, opt_param_scheduler) print_rank_0(">>> autoresume termination request found!") @@ -280,7 +309,7 @@ def get_ltor_masks_and_position_ids( attention_mask = None if not skip_mask: attention_mask = torch.tril( - torch.ones((att_mask_batch, seq_length, seq_length)) + torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) ).view(att_mask_batch, 1, seq_length, seq_length) # Loss mask. @@ -309,7 +338,11 @@ def get_ltor_masks_and_position_ids( for j in range(eod_index.size()[0]): i = eod_index[j] # Mask attention loss. 
- if reset_attention_mask and not skip_mask: + if ( + reset_attention_mask + and not skip_mask + and attention_mask is not None + ): attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 # Reset positions. if reset_position_ids: @@ -320,7 +353,6 @@ def get_ltor_masks_and_position_ids( if not skip_mask: assert attention_mask is not None attention_mask = attention_mask < 0.5 - attention_mask = attention_mask.to(data.device) return attention_mask, loss_mask, position_ids @@ -360,10 +392,7 @@ def is_rank_0(): if torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0 or ( is_aml() - and ( - torch.distributed.get_rank() - % get_accelerator().device_count() - ) == 0 + and (torch.distributed.get_rank() % get_accelerator().device_count()) == 0 ): return True else: @@ -392,6 +421,37 @@ def get_parameters_in_billions(model): return approx_parameters_in_billions * gpus_per_model / (1e9) +def num_floating_point_operations(args, batch_size): + # Group Query Attention. + # if not args.group_query_attention: + if not args.num_key_value_heads: + args.num_key_value_heads = args.num_attention_heads + # args.num_query_groups = args.num_attention_heads + # MoE. + # num_experts_routed_to = 1 if args.num_experts is None else args.moe_router_topk + num_experts_routed_to = 1 if args.num_experts is None else args.topk + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + return ( + 12 + * batch_size + * args.seq_length + * args.num_layers + * args.hidden_size + * args.hidden_size + * ( + 1 + + ( + (args.ffn_hidden_size / args.hidden_size) + * num_experts_routed_to + * gated_linear_multiplier + ) + + (args.num_key_value_heads / args.num_attention_heads) + + (args.seq_length / args.hidden_size) + + (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size)) + ) + ) + + def throughput_calculator(model, args, iteration_time, total_iterations): batch_size = ( args.micro_batch_size * get_num_microbatches() * args.data_parallel_size @@ -404,36 +464,58 @@ def throughput_calculator(model, args, iteration_time, total_iterations): # flops calculator hidden_size = args.hidden_size + num_attention_heads = args.num_attention_heads + head_dim = hidden_size // num_attention_heads + ffn_hidden_size = args.ffn_hidden_size num_layers = args.num_layers vocab_size = args.padded_vocab_size + gqa = args.num_attention_heads // args.num_key_value_heads + ffn_multiplier = 3 if args.swiglu else 2 + macs_per_flops = 2 # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of # https://arxiv.org/pdf/2104.04473.pdf). - # The factor of 4 is when used with activation check-pointing, - # otherwise it will be 3. 
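# Illustrative sketch (not part of the patch): the get_ltor_masks_and_position_ids hunk
# above now builds the causal mask directly on data.device and, when reset_attention_mask
# is set, zeroes attention across EOD boundaries so tokens cannot attend back into a
# previous document. A toy, single-sequence version of that reset (token values made up):
import torch

seq_length, eod = 8, 0
tokens = torch.tensor([5, 7, eod, 9, 4, eod, 3, 2])

# Causal (lower-triangular) mask: 1 = may attend.
mask = torch.tril(torch.ones(seq_length, seq_length))

# After each EOD token, block attention into everything at or before it.
for i in (tokens == eod).nonzero(as_tuple=True)[0].tolist():
    mask[i + 1 :, : i + 1] = 0

# Megatron then flips this into a boolean "masked out" tensor:
masked_out = mask < 0.5
print(masked_out.int())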
- checkpoint_activations_factor = 3 - if hasattr(args, "checkpoint_activations") and args.checkpoint_activations: - checkpoint_activations_factor = 4 - if hasattr(args, "recompute_granularity") and ( - args.recompute_granularity == "selective" - or args.recompute_granularity == "full" - ): - checkpoint_activations_factor = 4 + # correction has been made to TFLOPs formula due to incorrect behavior + # observed with selective recompute when GQA not used and for all with GQA seq_len = args.seq_length if hasattr(args, "actual_seq_length"): seq_len = args.actual_seq_length - flops_per_iteration = ( - 24 - * checkpoint_activations_factor - * batch_size - * seq_len + pre_and_post_mha_gemm_macs = ( + batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len + ) + mha_bgemm_macs = ( + batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2) + ) + ffn_gemm_macs = ( + batch_size * num_layers - * (hidden_size**2) - ) * ( - 1.0 - + (seq_len / (6.0 * hidden_size)) - + (vocab_size / (16.0 * num_layers * hidden_size)) + * ffn_multiplier + * ffn_hidden_size + * hidden_size + * seq_len ) + logit_lmhead_gemm_macs = batch_size * vocab_size * hidden_size * seq_len + + fwd_macs = ( + pre_and_post_mha_gemm_macs + + mha_bgemm_macs + + ffn_gemm_macs + + logit_lmhead_gemm_macs + ) + bwd_macs = 2 * fwd_macs + fwd_bwd_macs = fwd_macs + bwd_macs + + if (hasattr(args, "checkpoint_activations") and args.checkpoint_activations) or ( + hasattr(args, "recompute_granularity") and args.recompute_granularity == "full" + ): + fwd_bwd_macs += fwd_macs + if ( + hasattr(args, "recompute_granularity") + and args.recompute_granularity == "selective" + ): + fwd_bwd_macs += mha_bgemm_macs + + flops_per_iteration = fwd_bwd_macs * macs_per_flops tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions @@ -490,42 +572,54 @@ def dump_weights(preamble, iteration, model, optimizer, tensor=None): dp_rank = mpu.get_data_parallel_rank() dp_size = mpu.get_data_parallel_world_size() fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt" + # only care for first and last pp stages and dp0 tp0 # if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()): # return + # if not (tp_rank == 0 and dp_rank == 0): # return + if tensor is not None: orig_tensor = tensor if hasattr(tensor, "_hp_param"): numel = tensor._hp_param.numel() # // dp_size tensor = tensor.flatten().narrow(0, 0, numel) + # print(fn) with open(fn, "w") as fh: fh.write(f"{get_fingerprint_header()}\n") + if tensor is not None: fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n") else: for n, p in model[0].named_parameters(): fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n") + + # # until we figure out how to dump the actual fp32 values don't do this + # fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt" + # with open(fn, "w") as fh: + # fh.write(f"{get_fingerprint_header()}\n") + # if tensor is not None: + # tensor = orig_tensor + # if hasattr(tensor, "_hp_param"): + # fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n") + # #fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n") + # else: + # fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n") + # #fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n") + # + # else: + # if hasattr(model[0].module.tied_modules, "embed"): + # p = 
model[0].module.tied_modules.embed.word_embeddings.weight._hp_param + # fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n") return - # until we figure out how to dump the actual fp32 values don't do this - fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt" - with open(fn, "w") as fh: - fh.write(f"{get_fingerprint_header()}\n") - if tensor is not None: - tensor = orig_tensor - if hasattr(tensor, "_hp_param"): - fh.write( - f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n" - ) - # fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n") - else: - fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n") - # fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n") - else: - if hasattr(model[0].module.tied_modules, "embed"): - p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param - fh.write( - f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n" - ) + + +def found_kill_switch(): + args = get_args() + assert args is not None + if args.kill_switch_file is not None and os.path.exists(args.kill_switch_file): + return True + else: + return False diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py index 04018d7918..3686c6ceeb 100644 --- a/pretrain_gpt_alcf.py +++ b/pretrain_gpt_alcf.py @@ -1,12 +1,14 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain GPT""" + import time +from typing import Callable from mpi4py import MPI + comm = MPI.COMM_WORLD comm.Barrier() python_start_time = time.time() -from pathlib import Path import os from rich import print @@ -14,6 +16,7 @@ import math from functools import partial from megatron import get_args + # from megatron import print_rank_0 from megatron import get_timers from megatron import get_tokenizer @@ -23,14 +26,19 @@ from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids -from megatron.utils import average_losses_across_data_parallel_group, update_rotary_pos_emb +from megatron.utils import ( + average_losses_across_data_parallel_group, + update_rotary_pos_emb, +) from megatron.arguments import core_transformer_config_from_args + # from megatron.utils import Profile, PerfTrace import logging import deepspeed from deepspeed.runtime.utils import see_memory_usage + # from deepspeed.accelerator.real_accelerator import get_accelerator import subprocess import wandb @@ -38,7 +46,8 @@ from torch import nn import torch.nn.functional as F import ezpz as ez -dt_imports = time.time() - python_start_time + +dt_imports = time.time() - python_start_time t0_setup = time.time() # ---- [SETUP COMMS] ------------------------ @@ -62,19 +71,12 @@ log.info(f"ez.setup_torch time: {dt_setup} seconds") # ---- [SETUP WANDB FROM RANK 0] -------------- -WANDB_MODE = os.environ.get('WANDB_MODE', None) -DISABLE_WANDB = ( - WANDB_MODE is not None and str(WANDB_MODE).lower() == 'disabled' -) +WANDB_MODE = os.environ.get("WANDB_MODE", None) +DISABLE_WANDB = WANDB_MODE is not None and str(WANDB_MODE).lower() == "disabled" if RANK == 0 and not DISABLE_WANDB: - project_name = ( - os.environ.get( - 'WB_PROJECT', # look for WB_PROJECT in env - os.environ.get( - 'WANDB_PROJECT', # look for WANDB_PROJECT in env - 'AuroraGPT' - ), - ) + project_name = os.environ.get( + "WB_PROJECT", # look for WB_PROJECT in env + os.environ.get("WANDB_PROJECT", "AuroraGPT"), # look for WANDB_PROJECT in env ) 
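# Illustrative sketch (not part of the patch): the rewritten throughput_calculator in
# megatron/utils.py above estimates TFLOPS from explicit MACs -- the QKV/output
# projections, the two attention batched GEMMs, the FFN GEMMs, and the LM head -- with
# backward counted as twice forward, extra terms under activation recompute, and two
# FLOPs per MAC. A condensed version of that arithmetic with hypothetical shapes and
# timing (all values below are assumptions):
batch_size, seq_len, num_layers = 8, 4096, 32
hidden, heads, kv_heads = 4096, 32, 32
ffn_hidden, vocab = 11008, 32000
head_dim, gqa = hidden // heads, heads // kv_heads
ffn_multiplier = 3          # SwiGLU: gate, up, and down projections
flops_per_mac = 2           # one multiply-accumulate counts as two FLOPs

pre_post_mha = batch_size * num_layers * (1 + 2 // gqa + 1) * hidden**2 * seq_len
mha_bgemm = batch_size * num_layers * 2 * head_dim * heads * seq_len**2
ffn_gemm = batch_size * num_layers * ffn_multiplier * ffn_hidden * hidden * seq_len
lm_head = batch_size * vocab * hidden * seq_len

fwd_macs = pre_post_mha + mha_bgemm + ffn_gemm + lm_head
fwd_bwd_macs = fwd_macs + 2 * fwd_macs     # backward ~ 2x forward
# + fwd_macs again under full activation recompute,
# + mha_bgemm under selective recompute.

elapsed_per_iter_s, world_size = 10.0, 8   # made-up timing
tflops = fwd_bwd_macs * flops_per_mac / (elapsed_per_iter_s * world_size * 1e12)
print(f"{tflops:.1f} TFLOPS per GPU")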
log.info(f"Setting up W&B from: {RANK} with {project_name}") _ = ez.setup_wandb(project_name=project_name) @@ -83,16 +85,16 @@ @ez.dist.timeitlogit(rank=RANK) def model_provider(pre_process=True, post_process=True): """Build the model.""" - log.info('building GPT model ...') + log.info("building GPT model ...") see_memory_usage("Before Building Model", force=True) args = get_args() assert args is not None config = core_transformer_config_from_args(args) # if RANK == 0: # git_ds_info() - if hasattr(mpu, 'get_sequence_data_parallel_group'): + if hasattr(mpu, "get_sequence_data_parallel_group"): dpg = mpu.get_sequence_data_parallel_group() - elif hasattr(mpu, 'get_data_parallel_group'): + elif hasattr(mpu, "get_data_parallel_group"): dpg = mpu.get_data_parallel_group() else: dpg = None @@ -100,20 +102,14 @@ def model_provider(pre_process=True, post_process=True): if args.use_mics: deepspeed_zero_init = deepspeed.zero.MiCS_Init with deepspeed_zero_init( - data_parallel_group=dpg, - remote_device=( - None if args.remote_device == 'none' else args.remote_device - ), - config_dict_or_path=args.deepspeed_config_dict, - enabled=args.zero_stage == 3, - mpu=mpu + data_parallel_group=dpg, + remote_device=(None if args.remote_device == "none" else args.remote_device), + config_dict_or_path=args.deepspeed_config, # _dict, + enabled=args.zero_stage == 3, + mpu=mpu, ): if args.deepspeed and not args.no_pipeline_parallel: - model = GPTModelPipe( - config=config, - num_tokentypes=0, - parallel_output=True - ) + model = GPTModelPipe(config=config, num_tokentypes=0, parallel_output=True) # This is a hack to give us a reference to # get_batch_pipe from within training.py # We need to call model.set_batch_fn after deepspeed.initialize @@ -129,7 +125,7 @@ def model_provider(pre_process=True, post_process=True): ) ).view(1, 1, args.seq_length, args.seq_length) # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) + attention_mask = attention_mask < 0.5 if args.fp16: attention_mask = attention_mask.half() elif args.bf16: @@ -146,37 +142,33 @@ def model_provider(pre_process=True, post_process=True): num_tokentypes=0, parallel_output=True, pre_process=pre_process, - post_process=post_process + post_process=post_process, ) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - log.info(80 * '-') + log.info(80 * "-") log.info(f"Number of parameters in model: {num_params}") - log.info(80 * '-') + log.info(80 * "-") see_memory_usage("After Building Model", force=True) - if wandb is not None and getattr(wandb, 'run', None) is not None: + if wandb is not None and getattr(wandb, "run", None) is not None: assert wandb.run is not None tbdir = args.tensorboard_dir # tbdir = args.getattr('tensorboard_dir', None) if tbdir is not None: try: - log.info(f'Patching tensorboard from {tbdir}') + log.info(f"Patching tensorboard from {tbdir}") wandb.tensorboard.patch(root_logdir=tbdir) except ValueError as exc: log.exception(exc) - log.warning('Continuing without patching tensorboard!') - wandb.run.config.update({'num_params': num_params}) + log.warning("Continuing without patching tensorboard!") + wandb.run.config.update({"num_params": num_params}) if "args" not in wandb.run.config: log.info( f"Updating WandB run.config: [{wandb.run.name}]({wandb.run.get_url()})" ) try: - wandb.run.config.update( - {"args": dict(sorted(vars(args).items()))} - ) + wandb.run.config.update({"args": dict(sorted(vars(args).items()))}) except Exception: - log.error( - 'Unable to `wandb.run.config.update({"args": 
vars(args)})`' - ) + log.error('Unable to `wandb.run.config.update({"args": vars(args)})`') # try: # wandb.run.watch( # model, @@ -194,9 +186,17 @@ def get_batch(data_iterator): tokenizer = get_tokenizer() assert args is not None and tokenizer is not None # Items and their type. - keys = ['text'] + keys = ["text"] datatype = torch.int64 data = next(data_iterator) if data_iterator is not None else None + + if ( + args.iteration < 10 + and RANK == 0 + and os.environ.get("DUMP_TOKENS", None) + and data is not None + ): + log.info(f"{args.iteration=}: {data['text'][:10]=}") # # Broadcast data. # if data_iterator is not None: # data = next(data_iterator) @@ -204,7 +204,7 @@ def get_batch(data_iterator): # data = None data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. - tokens_ = data_b['text'].long() + tokens_ = data_b["text"].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. @@ -215,7 +215,8 @@ def get_batch(data_iterator): args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss, - skip_mask) + skip_mask, + ) # For DS's sequence parallel seq_parallel_world_size = mpu.get_sequence_parallel_world_size() seq_parallel_world_rank = mpu.get_sequence_parallel_rank() @@ -240,24 +241,37 @@ def data_post_process(data, data_sampler_state_dict): args = get_args() assert args is not None if args.data_efficiency_curriculum_learning: - if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: - args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' - current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if "seqlen_truncate" in data_sampler_state_dict["current_difficulties"]: + args.data_efficiency_curriculum_learning_seqlen_type = "seqlen_truncate" + current_seqlen = data_sampler_state_dict["current_difficulties"][ + "seqlen_truncate" + ] if current_seqlen < args.seq_length: - data['text'] = data['text'][:, :(current_seqlen+1)].contiguous() - elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: - args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' - current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + data["text"] = data["text"][:, : (current_seqlen + 1)].contiguous() + elif "seqlen_reshape" in data_sampler_state_dict["current_difficulties"]: + args.data_efficiency_curriculum_learning_seqlen_type = "seqlen_reshape" + current_seqlen = data_sampler_state_dict["current_difficulties"][ + "seqlen_reshape" + ] if current_seqlen < args.seq_length: - orig_num_token = torch.numel(data['text']) - reshape_len = (data['text'].size()[1] // (current_seqlen+1)) * (current_seqlen+1) - data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen+1), - data['text'][:, -(current_seqlen+1):]), 0).contiguous() - num_row = math.ceil(orig_num_token / (current_seqlen+1)) - num_row = min(num_row, data['text'].size()[0]) + orig_num_token = torch.numel(data["text"]) + reshape_len = (data["text"].size()[1] // (current_seqlen + 1)) * ( + current_seqlen + 1 + ) + data["text"] = torch.cat( + ( + data["text"][:, :reshape_len] + .contiguous() + .view(-1, current_seqlen + 1), + data["text"][:, -(current_seqlen + 1) :], + ), + 0, + ).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen + 1)) + num_row = min(num_row, data["text"].size()[0]) if num_row > 1 and num_row % 2 != 0: num_row -= 1 - data['text'] = data['text'][:num_row, :].contiguous() + data["text"] = 
data["text"][:num_row, :].contiguous() else: args.data_efficiency_curriculum_learning_seqlen_type = None return data @@ -272,12 +286,12 @@ def get_batch_pipe(data): tokenizer = get_tokenizer() assert args is not None # Items and their type. - keys = ['text'] + keys = ["text"] datatype = torch.int64 # Broadcast data. data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. - tokens_ = data_b['text'].long() + tokens_ = data_b["text"].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. @@ -286,19 +300,17 @@ def get_batch_pipe(data): tokenizer.eod, args.reset_position_ids, args.reset_attention_mask, - args.eod_mask_loss) - if ( - args.curriculum_learning_legacy - and args.curriculum_seqlen < tokens.size()[1] - ): + args.eod_mask_loss, + ) + if args.curriculum_learning_legacy and args.curriculum_seqlen < tokens.size()[1]: # seqlen-based curriculum learning # tokens, position_ids, labels, loss_mask # have size [batch size, seqlen] - tokens = tokens[:, :args.curriculum_seqlen].contiguous() - position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + tokens = tokens[:, : args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, : args.curriculum_seqlen].contiguous() if labels is not None: - labels = labels[:, :args.curriculum_seqlen].contiguous() - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + labels = labels[:, : args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, : args.curriculum_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) @@ -315,37 +327,32 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): loss = loss + moe_loss + mos_loss if args.mos: return loss, { - 'total loss': loss, - 'lm loss': averaged_loss[0], - 'moe loss': moe_loss, - 'mos loss': mos_loss + "total loss": loss, + "lm loss": averaged_loss[0], + "moe loss": moe_loss, + "mos loss": mos_loss, } elif args.kd: return loss, { - 'total loss': loss, - 'lm loss': averaged_loss[0], - 'moe loss': moe_loss, - 'kd loss': mos_loss + "total loss": loss, + "lm loss": averaged_loss[0], + "moe loss": moe_loss, + "kd loss": mos_loss, } log.info( - f'>>> total loss: {loss}, ' - f'lm loss {averaged_loss[0]}, ' - f'kd loss {mos_loss}' + f">>> total loss: {loss}, " + f"lm loss {averaged_loss[0]}, " + f"kd loss {mos_loss}" ) else: if max(args.num_experts) <= 1: - return loss, {'lm loss': averaged_loss[0]} + return loss, {"lm loss": averaged_loss[0]} loss = loss + moe_loss - return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + return loss, {"lm loss": averaged_loss[0], "moe loss": moe_loss} def calculate_mos_loss( - args, - stu_output, - teacher_model, - tokens, - position_ids, - attention_mask + args, stu_output, teacher_model, tokens, position_ids, attention_mask ): mos_loss = 0 alpha = args.kd_alpha_ce @@ -354,29 +361,25 @@ def calculate_mos_loss( if teacher_model: with torch.no_grad(): if ( - args.curriculum_learning_legacy and - args.curriculum_seqlen < args.seq_length + args.curriculum_learning_legacy + and args.curriculum_seqlen < args.seq_length ): assert args.curriculum_seqlen is not None curriculum_seqlen = args.curriculum_seqlen tokens = tokens[:, :curriculum_seqlen].contiguous() position_ids = position_ids[:, :curriculum_seqlen].contiguous() csl = curriculum_seqlen - attention_mask = ( - attention_mask[:, :, :csl, :csl].contiguous() - ) + attention_mask = attention_mask[:, :, :csl, :csl].contiguous() # No need to truncate labels # as we do not 
need it for the teacher logits tea_output, tea_other_losses = teacher_model( - tokens, - position_ids, - attention_mask + tokens, position_ids, attention_mask ) assert stu_output.size() == tea_output.size(), ( - 'teacher and student output should match in size. ' - f'Student: {stu_output.size()}, ' - f'Teacher: {tea_output.size()}, ' - f'CL seq length {args.curriculum_seqlen}' + "teacher and student output should match in size. " + f"Student: {stu_output.size()}, " + f"Teacher: {tea_output.size()}, " + f"CL seq length {args.curriculum_seqlen}" ) student_logits = F.log_softmax(stu_output / kd_temp, dim=2) # The target logits is expected to be probabilities. @@ -384,67 +387,48 @@ def calculate_mos_loss( # then we need to set target_log to true # when initializing the KLDivLoss. tea_logits = F.softmax(tea_output / kd_temp, dim=2) - mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')( - student_logits, - tea_logits + mos_loss = ( + kd_temp + * kd_temp + * nn.KLDivLoss(reduction="batchmean")(student_logits, tea_logits) ) mos_loss = mos_loss.div(args.seq_length) * beta return mos_loss -def forward_step(data_iterator, model): +def forward_step(data_iterator, model) -> tuple[torch.Tensor | None, Callable]: """Forward step.""" args = get_args() timers = get_timers() assert args is not None assert timers is not None # Get the batch. - timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator - ) - timers('batch-generator').stop() + timers("batch-generator", log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator) + timers("batch-generator").stop() if args.data_efficiency_curriculum_learning: args.curriculum_seqlen = tokens.size()[1] - if ( - hasattr( - args, - 'data_efficiency_curriculum_learning_seqlen_type') - and ( - args.data_efficiency_curriculum_learning_seqlen_type - == 'seqlen_reshape' - ) + if hasattr(args, "data_efficiency_curriculum_learning_seqlen_type") and ( + args.data_efficiency_curriculum_learning_seqlen_type == "seqlen_reshape" ): - args.data_efficiency_curriculum_learning_numel = ( - torch.numel(tokens) - ) + args.data_efficiency_curriculum_learning_numel = torch.numel(tokens) stu_output = None if args.mos or args.kd: # The forward func can return either the loss or the logits, # depending on whether passing in the labels or not. 
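The mos/kd branch above reduces to a standard knowledge-distillation objective: soften both distributions with a temperature, take the KL divergence of student log-probabilities against teacher probabilities, and rescale by the squared temperature. A minimal, self-contained sketch of that term (the tensor shapes, `kd_temp`, and `beta` values are illustrative defaults, not the patch's `args`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, kd_temp=2.0, beta=1.0):
    """Temperature-scaled KL divergence, shaped like calculate_mos_loss above:
    student log-probabilities vs. softened teacher probabilities, rescaled by
    kd_temp**2 and averaged over the sequence length."""
    seq_length = student_logits.size(1)
    student_logp = F.log_softmax(student_logits / kd_temp, dim=2)
    teacher_p = F.softmax(teacher_logits / kd_temp, dim=2)
    loss = kd_temp * kd_temp * nn.KLDivLoss(reduction="batchmean")(student_logp, teacher_p)
    return loss.div(seq_length) * beta

# toy shapes: [batch=2, seq=8, vocab=32]
stu = torch.randn(2, 8, 32)
tea = torch.randn(2, 8, 32)
print(kd_loss(stu, tea))
```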
stu_output, other_losses = model(tokens, position_ids, attention_mask) - if ( - args.curriculum_learning_legacy - and args.curriculum_seqlen < args.seq_length - ): + if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: assert args.curriculum_seqlen is not None - labels = labels[:, :args.curriculum_seqlen].contiguous() + labels = labels[:, : args.curriculum_seqlen].contiguous() output_tensor = tensor_parallel.vocab_parallel_cross_entropy( - stu_output.contiguous().float(), - labels + stu_output.contiguous().float(), labels ) else: output_tensor, other_losses = model( - tokens, - position_ids, - attention_mask, - labels=labels + tokens, position_ids, attention_mask, labels=labels ) - if ( - args.curriculum_learning_legacy and - args.curriculum_seqlen < args.seq_length - ): - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: + loss_mask = loss_mask[:, : args.curriculum_seqlen].contiguous() moe_losses = [] for moe_loss in other_losses: @@ -462,7 +446,7 @@ def forward_step(data_iterator, model): args.teacher_model[0], tokens, position_ids, - attention_mask + attention_mask, ) # Output_tensor stores the standard loss, @@ -479,7 +463,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): # from ezpz.profile import get_context_manager # cm = get_context_manager(rank=RANK, outdir=args.save) # with cm: - log.info('> building train, validation, and test datasets for GPT ...') + log.info("> building train, validation, and test datasets for GPT ...") files = [] if args.data_file_list is not None: log.info(f"Reading datasets from {args.data_file_list}") @@ -492,7 +476,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): # - `/path/to/data_text_document` is the path to the text document # - `corpus` is the corpus (~ source, can be made up) where that # document came from (i.e. `books`, `arxiv`, etc.) 
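The comments above describe the weighted data-file-list format: each non-empty line carries a sampling weight, a dataset prefix (the `.bin`/`.idx` path without extension), and a corpus label. A standalone sketch of that parsing step, with made-up example paths:

```python
def parse_data_file_list(path):
    """Parse a weighted data-file list into the flat
    [weight, prefix, corpus, weight, prefix, corpus, ...] layout that
    train_valid_test_datasets_provider builds above."""
    files = []
    with open(path, "r") as flist:
        for line in flist:
            if not line.strip():
                continue
            w, fname, corpus = line.split()
            if ".bin" in fname:
                # Megatron expects the dataset prefix, not the .bin file itself
                fname = fname.split(".bin")[0]
            files.extend([float(w), fname, corpus])
    return files

# hypothetical list contents, one "<weight> <path> <corpus>" triple per line:
#   0.5 /path/to/books/books_text_document.bin books
#   0.5 /path/to/arxiv/arxiv_text_document.bin arxiv
```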
- with open(args.data_file_list, 'r') as flist: + with open(args.data_file_list, "r") as flist: for f in flist.readlines(): if len(f.strip()) != 0: try: @@ -505,17 +489,11 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ) if fname.find(".bin") != -1: fname = fname.split(".bin")[0] - files.extend( - [ - float(w), # weight - fname, # filename - c # corpus - ] - ) + files.extend([float(w), fname, c]) # weight # filename # corpus elif len(args.data_path) == 1 and os.path.isdir(args.data_path[0]): path = args.data_path[0] + "/" for f in os.listdir(path): - if (os.path.isfile(path + f) and f.find(".bin") != -1): + if os.path.isfile(path + f) and f.find(".bin") != -1: files.append(1) files.append(path + f.split(".bin")[0]) else: @@ -540,11 +518,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): def command_exists(cmd): - result = subprocess.Popen( - f'type {cmd}', - stdout=subprocess.PIPE, - shell=True - ) + result = subprocess.Popen(f"type {cmd}", stdout=subprocess.PIPE, shell=True) return result.wait() == 0 @@ -552,17 +526,18 @@ def git_ds_info(): if RANK != 0: return from deepspeed.env_report import main as ds_report + ds_report() # Write out version/git info git_hash_cmd = "git rev-parse --short HEAD" git_branch_cmd = "git rev-parse --abbrev-ref HEAD" - if command_exists('git'): + if command_exists("git"): try: result = subprocess.check_output(git_hash_cmd, shell=True) - git_hash = result.decode('utf-8').strip() + git_hash = result.decode("utf-8").strip() result = subprocess.check_output(git_branch_cmd, shell=True) - git_branch = result.decode('utf-8').strip() + git_branch = result.decode("utf-8").strip() except subprocess.CalledProcessError: git_hash = "unknown" git_branch = "unknown" @@ -570,21 +545,26 @@ def git_ds_info(): git_hash = "unknown" git_branch = "unknown" print( - f'**** Git info for Megatron: ' - f'git_hash={git_hash} git_branch={git_branch} ****' + f"**** Git info for Megatron: " + f"git_hash={git_hash} git_branch={git_branch} ****" ) def main(): - if os.getenv('TORCH_PROFILER_ENABLE') == '1': + if os.getenv("TORCH_PROFILER_ENABLE") == "1": # record_function from torch.profiler import profile, ProfilerActivity + try: - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU] + activities = [ + ProfilerActivity.CPU, + ProfilerActivity.CUDA, + ProfilerActivity.XPU, + ] except Exception as exc: log.exception(exc) - log.warning("TORCH PROFILER WARNING: XPU is not supported") - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA] + log.warning("TORCH PROFILER WARNING: XPU is not supported") + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] with profile(activities=activities) as prof: model = pretrain( train_valid_test_datasets_provider, @@ -592,7 +572,7 @@ def main(): ModelType.encoder_or_decoder, forward_step, # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - data_post_process=data_post_process + data_post_process=data_post_process, ) args = get_args() assert args is not None @@ -606,7 +586,7 @@ def main(): ModelType.encoder_or_decoder, forward_step, # args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - data_post_process=data_post_process + data_post_process=data_post_process, ) # try: # from megatron.text_generation import generate_and_post_process @@ -641,6 +621,7 @@ def main(): # data_post_process=data_post_process) import sys import deepspeed.comm as dist + model = main() dist.log_summary() if wandb.run is not None: diff --git a/requirements.txt b/requirements.txt deleted file mode 
100644 index 4f758bed7d..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -pydftracer -wandb -git+https://github.com/saforem2/ezpz@main diff --git a/test_agptllama.py b/test_agptllama.py new file mode 100644 index 0000000000..e1d207fa27 --- /dev/null +++ b/test_agptllama.py @@ -0,0 +1,34 @@ +import torch +import intel_extension_for_pytorch as ipex +from transformers import GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer, LlamaForCausalLM +def batch_encode(prompts, tokenizer, prompt_len=512): + input_tokens = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding="max_length", max_length=len(prompts)) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to("xpu") + #input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + return input_tokens + + +def generate_prompt(model, tokenizer, prompts): + + input_tokens = batch_encode(prompts, tokenizer) + print(input_tokens) + generate_kwargs = dict(max_new_tokens=30, do_sample=False) + output_ids = model.generate(**input_tokens, **generate_kwargs) + print(output_ids) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + return outputs + +if __name__ == '__main__': + + model = LlamaForCausalLM.from_pretrained("/flare/Aurora_deployment/vsastry/hf_new_cp/") + model.to("xpu") # model.cuda() + model.seqlen = 4096 + + # get llama tokenizer + tokenizer = LlamaTokenizer.from_pretrained("/flare/Aurora_deployment/AuroraGPT/datasets/dolma/utils/tokenizer.model") + tokenizer.pad_token = tokenizer.eos_token + output = generate_prompt(model, tokenizer, prompts=["What is the language spoken in Mexico ?"]) + print(output) diff --git a/tests/models/test_gpt_embedding.py b/tests/models/test_gpt_embedding.py index 700990adc2..199f29dede 100644 --- a/tests/models/test_gpt_embedding.py +++ b/tests/models/test_gpt_embedding.py @@ -1,15 +1,22 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
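The new `test_agptllama.py` above is essentially a smoke test for greedy generation from an HF-format Llama checkpoint on an XPU. A condensed, device-agnostic sketch of the same flow; the checkpoint path is a placeholder and the XPU probe assumes a PyTorch build that exposes `torch.xpu`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "/path/to/hf_llama_checkpoint"   # placeholder path
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(ckpt)
tokenizer.pad_token = tokenizer.eos_token      # Llama tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(ckpt).to(device)

# padding=True pads to the longest prompt in the batch
inputs = tokenizer(["What is the language spoken in Mexico ?"],
                   return_tensors="pt", padding=True).to(device)
output_ids = model.generate(**inputs, max_new_tokens=30, do_sample=False)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```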
import pytest import torch +import types from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.models.gpt.gpt_embedding import GPTEmbedding +from megatron.global_vars import set_args +from deepspeed.accelerator import get_accelerator +device_name = get_accelerator().device_name() @pytest.fixture def gpt_embedding(transformer_config): + args = types.SimpleNamespace(params_dtype=torch.float32, embed_layernorm=False) + set_args(args) embedding = GPTEmbedding(config=transformer_config, vocab_size=100, max_sequence_length=4) return embedding @@ -36,12 +43,12 @@ def test_cpu_forward(self, gpt_embedding: GPTEmbedding): assert embeddings.shape[1] == input_ids.shape[0] assert embeddings.shape[2] == gpt_embedding.config.hidden_size - def test_gpu_forward(self, gpt_embedding: GPTEmbedding): - gpt_embedding.cuda() - input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() - position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() + def test_accelerator_forward(self, gpt_embedding: GPTEmbedding): + gpt_embedding.to(device_name) + input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).to(device_name) + position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).to(device_name) embeddings = gpt_embedding(input_ids, position_ids) - assert embeddings.device.type == 'cuda' + assert embeddings.device.type == device_name assert embeddings.shape[0] == gpt_embedding.max_sequence_length assert embeddings.shape[1] == input_ids.shape[0] assert embeddings.shape[2] == gpt_embedding.config.hidden_size diff --git a/tests/models/test_gpt_model.py b/tests/models/test_gpt_model.py index b854ecd918..cf322908b3 100644 --- a/tests/models/test_gpt_model.py +++ b/tests/models/test_gpt_model.py @@ -1,20 +1,28 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
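The test changes above (and in the files that follow) all apply the same pattern: ask DeepSpeed which accelerator is active and move tensors or modules with `.to(device_name)` instead of hard-coding `.cuda()`. A minimal sketch of that pattern:

```python
import torch
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()
device_name = accel.device_name()    # "cuda", "xpu", "hpu", or "cpu" depending on the build

x = torch.ones(4, 4)
if accel.is_available():
    x = x.to(device_name)            # replaces hard-coded .cuda() calls
print(x.device.type, device_name)
```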
import pytest import torch +import types from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.global_vars import set_args +from deepspeed.accelerator import get_accelerator +device_name = get_accelerator().device_name() @pytest.fixture def gpt_model(transformer_config): + args = types.SimpleNamespace(params_dtype=torch.float32, embed_layernorm=False) + set_args(args) language_model = GPTModel(config=transformer_config, vocab_size=100, max_sequence_length=4) return language_model class TestGPTModel: + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_constructor(self, gpt_model: GPTModel): assert isinstance(gpt_model, GPTModel) @@ -23,6 +31,7 @@ def test_constructor(self, gpt_model: GPTModel): num_weights = sum([p.numel() for p in gpt_model.parameters()]) assert num_weights == 5040 + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_set_input_tensor(self, gpt_model: GPTModel): config: TransformerConfig = gpt_model.config sequence_length = gpt_model.max_sequence_length @@ -37,17 +46,18 @@ def test_set_input_tensor(self, gpt_model: GPTModel): assert gpt_model.decoder.input_tensor.shape[1] == micro_batch_size assert gpt_model.decoder.input_tensor.shape[2] == config.hidden_size + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_post_process_forward(self, gpt_model: GPTModel): config: TransformerConfig = gpt_model.config sequence_length = gpt_model.max_sequence_length micro_batch_size = 2 - gpt_model.cuda() + gpt_model.to(device_name) data = list(range(sequence_length)) - input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() + input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device_name) + position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).to(device_name) + attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).to(device_name) logits = gpt_model.forward(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) @@ -55,15 +65,19 @@ def test_post_process_forward(self, gpt_model: GPTModel): assert logits.shape[1] == sequence_length assert logits.shape[2] == gpt_model.vocab_size + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_no_post_process_forward(self, gpt_model: GPTModel): pass + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_no_preprocess_forward(self, gpt_model: GPTModel): pass + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_state_dict_for_save_checkpoint(self, gpt_model: GPTModel): pass + @pytest.mark.xfail(device_name=='hpu', reason="TELayerNorm is not defined in HPU") def test_load_state_dict(self, gpt_model: GPTModel): pass diff --git a/tests/pipeline_parallel/test_schedules.py b/tests/pipeline_parallel/test_schedules.py index a6bac5b2a3..72c2372ba4 100644 --- a/tests/pipeline_parallel/test_schedules.py +++ b/tests/pipeline_parallel/test_schedules.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
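The `xfail` markers added above are conditional: the test only becomes an expected failure when the active accelerator is HPU, and it still runs (and must pass) everywhere else. A small illustration of the marker with a trivial placeholder test body:

```python
import pytest
from deepspeed.accelerator import get_accelerator

device_name = get_accelerator().device_name()

# Expected to fail only when the condition is true at collection time;
# on every other accelerator the test runs and must pass as usual.
@pytest.mark.xfail(device_name == "hpu", reason="TELayerNorm is not defined in HPU")
def test_placeholder():
    assert True
```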
+ import torch from tests.test_utilities import Utils from megatron.core import ModelParallelConfig @@ -21,7 +23,9 @@ def test_get_forward_backward_func(): def test_deallocate_output_tensor(): out = torch.tensor([[1, 2, 3], [4, 5, 6]]) schedule.deallocate_output_tensor(out) - assert(out.nelement() == 1) + assert(out.nelement() == 6) + schedule.deallocate_output_tensor(out, True) + assert(out.nelement() == 1) def test_forward_backward_func_without_pipeline_parallel(mocker): from megatron.core.pipeline_parallel import get_forward_backward_func diff --git a/tests/transformer/test_parallel_mlp.py b/tests/transformer/test_parallel_mlp.py index f43dc0b467..098f18a9d6 100644 --- a/tests/transformer/test_parallel_mlp.py +++ b/tests/transformer/test_parallel_mlp.py @@ -1,14 +1,30 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest import torch +import types from megatron.core.transformer.parallel_mlp import ParallelMLP +from megatron.global_vars import set_args +from deepspeed.accelerator import get_accelerator +device_name = get_accelerator().device_name() @pytest.fixture def mlp(transformer_config): + mlp_args = types.SimpleNamespace( + swiglu=False, + openai_gelu=True, + onnx_safe=False, + bias_gelu_fusion=False, + transformer_impl="", + cache_fp8_weight=False, + fp8_interval=False, + cache_fp8_weight_fwd=False + ) + set_args(mlp_args) return ParallelMLP(transformer_config) @@ -19,28 +35,27 @@ def test_constructor(self, mlp): num_weights = sum([p.numel() for p in mlp.parameters()]) assert num_weights == 1212 - def test_cpu_forward(self, mlp): + def test_cpu_forward(self, mlp, transformer_config): # [sequence length, micro batch size, hidden size] - hidden_states = torch.ones((32, 2, mlp.config.hidden_size)) + hidden_states = torch.ones((32, 2, transformer_config.hidden_size)) output, output_bias = mlp(hidden_states) assert output.shape[0] == 32 assert output.shape[1] == 2 - assert output.shape[2] == mlp.config.hidden_size - assert output_bias.shape[0] == mlp.config.hidden_size + assert output.shape[2] == transformer_config.hidden_size + assert output_bias == None assert output.dtype == torch.float32 - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_gpu_forward(self, mlp): - mlp.cuda() + @pytest.mark.skipif(not get_accelerator().is_available(), reason="accelerator not available") + def test_accelerator_forward(self, mlp, transformer_config): + mlp.to(device_name) # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, mlp.config.hidden_size)) - hidden_states = hidden_states.cuda() + hidden_states = torch.ones((32, 2, transformer_config.hidden_size)) + hidden_states = hidden_states.to(device_name) output, output_bias = mlp(hidden_states) assert output.shape[0] == 32 assert output.shape[1] == 2 - assert output.shape[2] == mlp.config.hidden_size - assert output_bias.shape[0] == mlp.config.hidden_size + assert output.shape[2] == transformer_config.hidden_size + assert output_bias == None assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - assert output_bias.device.type == 'cuda' + assert output.device.type == device_name diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py index b35c77b58d..68c6e6b55c 100644 --- a/tests/unit_tests/test_utilities.py +++ b/tests/unit_tests/test_utilities.py @@ -1,21 +1,25 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
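The `mlp` fixture above shows the pattern these tests use for code paths that call Megatron's `get_args()`: build a `types.SimpleNamespace` containing only the fields the unit under test reads, then register it with `set_args` before constructing the module. A hedged sketch (`make_minimal_args` and its field list are illustrative, not a fixed API):

```python
import types
import torch
from megatron.global_vars import set_args

def make_minimal_args(**overrides):
    """Hypothetical helper: build only the global args the unit under test
    reads, then hand them to Megatron's global store so get_args() works."""
    args = types.SimpleNamespace(params_dtype=torch.float32, embed_layernorm=False)
    for key, value in overrides.items():
        setattr(args, key, value)
    return args

set_args(make_minimal_args(swiglu=False, bias_gelu_fusion=False))
```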
+ import os import torch import megatron.core.parallel_state as ps +from deepspeed.accelerator import get_accelerator + class Utils: - world_size = torch.cuda.device_count() - rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.getenv("WORLD_SIZE", '1')) + rank = int(os.getenv('LOCAL_RANK', '0')) @staticmethod def initialize_distributed(): print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}') - torch.cuda.set_device(Utils.rank % torch.cuda.device_count()) + get_accelerator().set_device(Utils.rank % get_accelerator().device_count()) init_method = 'tcp://' master_ip = os.getenv('MASTER_ADDR', 'localhost') master_port = os.getenv('MASTER_PORT', '6000') init_method += master_ip + ':' + master_port - torch.distributed.init_process_group(backend='nccl', world_size=Utils.world_size, rank=Utils.rank, init_method=init_method) + torch.distributed.init_process_group(backend=get_accelerator().communication_backend_name(), world_size=Utils.world_size, rank=Utils.rank, init_method=init_method) @staticmethod def destroy_model_parallel(): @@ -23,8 +27,8 @@ def destroy_model_parallel(): torch.distributed.barrier() @staticmethod - def initialize_model_parallel(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1, virtual_pipeline_model_parallel_size = None, pipeline_model_parallel_split_rank = None): + def initialize_model_parallel(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1, sequence_parallel_size = 1, virtual_pipeline_model_parallel_size = None, pipeline_model_parallel_split_rank = None): ps.destroy_model_parallel() if not torch.distributed.is_initialized(): Utils.initialize_distributed() - ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank) \ No newline at end of file + ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, sequence_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank) \ No newline at end of file diff --git a/tools/hf2megads_weight_converter.py b/tools/hf2megads_weight_converter.py index bfbde1fd05..12468963c5 100755 --- a/tools/hf2megads_weight_converter.py +++ b/tools/hf2megads_weight_converter.py @@ -3,9 +3,11 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch.distributed from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import print_rank_0, get_tokenizer, get_args from megatron.core import mpu +from megatron.core import tensor_parallel from megatron.core.utils import divide from megatron.model import GPTModelPipe, Float16Module from megatron.utils import unwrap_model @@ -13,20 +15,30 @@ from megatron.arguments import core_transformer_config_from_args from megatron.initialize import initialize_megatron from megatron.optimizer import get_megatron_optimizer -from megatron.checkpointing import save_checkpoint +from megatron.checkpointing import save_checkpoint, load_checkpoint from megatron.training import get_optimizer_param_scheduler from deepspeed.runtime.utils import see_memory_usage import deepspeed +import copy +from pathlib import Path + def add_extra_args(parser): """Text generation arguments.""" group = parser.add_argument_group(title='hf2mega') - group.add_argument("--hf-ckpt-num-shards", type=int, help='num of llama ckpt.') - group.add_argument("--origin-hf-ckpt-dir", + group.add_argument("--hf-ckpt-dir", type=str, 
default="", - help="the original path of the llama-hf ckpt") + help="the llama-hf ckpt") + group.add_argument("--hf-ckpt-num-shards", type=int, default=-1, help='num of llama ckpt.') + group.add_argument("--load-mode", type=str, + default=None, + choices=['torchbin', 'safetensor', 'auto'], + help="load ckpt format: pytorch.bin or model.safetensor or auto.") + group.add_argument("--to-hf-ckpt", action="store_true", + help="by default convert from hf to megads" + "if set, convert reversely from megads to hf ckpt.") return parser @@ -55,6 +67,49 @@ def load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards): return loaded +def load_and_print_hf_weight_from_safetensor(hf_ckpt_dir, hf_ckpt_num_of_shards): + from safetensors import safe_open + # Optimization point: We can selectively load specific 'shared' data to reduce CPU memory usage. + hf_model = {} + print_rank_0( + f"----------------------------hf weight list----------------------------") + + for wid in range(1, hf_ckpt_num_of_shards + 1): + if hf_ckpt_num_of_shards == 1: + ckpt_path = f"{hf_ckpt_dir}/model.safetensors" + else: + ckpt_path = f"{hf_ckpt_dir}/model-{wid:05d}-of-{hf_ckpt_num_of_shards:05d}.safetensors" + + with safe_open(ckpt_path, framework="pt", device="cpu") as f: + for k in f.keys(): + print_rank_0(f"name: {k}, shape: {f.get_tensor(k).shape}") + assert k not in hf_model + hf_model[k] = f.get_tensor(k).clone() + + return hf_model + + +def load_and_print_hf_weight_auto(hf_ckpt_dir, no_init=True): + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.modeling_utils import no_init_weights + + if no_init: + hf_config = AutoConfig.from_pretrained(hf_ckpt_dir, trust_remote_code=True) + with no_init_weights(): + hf_model = AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True, torch_dtype=torch.bfloat16) + else: + hf_model = {} + hf_auto_model = AutoModelForCausalLM.from_pretrained(hf_ckpt_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) + print_rank_0( + f"----------------------------hf weight list----------------------------") + + for name, param in hf_auto_model.named_parameters(): + hf_model[name] = param.clone() + print_rank_0(name) + + return hf_model + + def print_distinct_weights(model): print_rank_0( f"----------------------------mega-ds weight list----------------------------") @@ -70,16 +125,19 @@ def print_distinct_weights(model): class refactor: - def __init__(self, model, loaded, args, config): + def __init__(self, ds_model, hf_model, args, config): tokenizer = get_tokenizer() # align layer number - self.model = model - self.loaded = loaded + self.ds_model = ds_model + self.hf_model = hf_model + self.hf_dict = {} # for handling pp case when converting mds => hf self.config = config self.offset_num = 2 self.mega_emb_wnum = 1 self.mega_norm_wnum = args.num_layers + 2 + self.num_attention_heads = args.num_attention_heads + self.num_key_value_heads = args.num_key_value_heads self.mega_lm_head_wnum = self.mega_norm_wnum + 1 self.token_vocab = tokenizer.vocab_size self.padded_vocab_size = args.padded_vocab_size @@ -95,7 +153,7 @@ def _embedding_refactor(self, pname, p): hf_name = "lm_head.weight" elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight": hf_name = "model.embed_tokens.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] assert hf_w.shape[0] == self.token_vocab per_partition_vocab_size, start_index, end_index = compute_partition_range( self.padded_vocab_size, self.tp_rank, self.tp_size) @@ -112,24 +170,28 @@ def 
_embedding_refactor(self, pname, p): ) return new_w + + + def _direct_refactor(self, pname, p, hf_layer=None, subname=None): if pname == f"{self.mega_norm_wnum}.weight": hf_name = "model.norm.weight" elif subname in ["input_layernorm.weight", "post_attention_layernorm.weight"]: hf_name = f"model.layers.{hf_layer}.{subname}" - new_w = hf_w = self.loaded[hf_name] + new_w = hf_w = self.hf_model[hf_name] self.record_mapping_info( f"mega-ds:{pname,p.data.shape}<--hf{hf_name,} {hf_w.shape}") return new_w + def _qkv_refactor(self, pname, p, hf_layer): hf_wq_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight" hf_wk_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight" hf_wv_name = f"model.layers.{hf_layer}.self_attn.v_proj.weight" - wq = self.loaded[hf_wq_name] - wk = self.loaded[hf_wk_name] - wv = self.loaded[hf_wv_name] + wq = self.hf_model[hf_wq_name] + wk = self.hf_model[hf_wk_name] + wv = self.hf_model[hf_wv_name] hidden_size = wq.shape[0] per_partition_size, start_index, end_index = compute_partition_range( @@ -159,8 +221,8 @@ def _qkv_refactor(self, pname, p, hf_layer): def _mlphto4h_dense_refactor(self, pname, p, hf_layer): hf_w_gate_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" hf_w_up_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" - w_gate = self.loaded[hf_w_gate_name] - w_up = self.loaded[hf_w_up_name] + w_gate = self.hf_model[hf_w_gate_name] + w_up = self.hf_model[hf_w_up_name] hidden_size = w_gate.shape[0] per_partition_size, start_index, end_index = compute_partition_range( @@ -184,7 +246,7 @@ def _attn_dense_refactor(self, pname, p, hf_layer, subname): else: hf_name = f"model.layers.{hf_layer}.mlp.down_proj.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] hidden_size = hf_w.shape[1] per_partition_size, start_index, end_index = compute_partition_range( hidden_size, self.tp_rank, self.tp_size) @@ -200,7 +262,7 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname): hf_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" else: hf_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] hidden_size = hf_w.shape[0] per_partition_size, start_index, end_index = compute_partition_range( hidden_size, self.tp_rank, self.tp_size) @@ -212,10 +274,11 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname): ) return new_w - def refactor(self): + def transform_from_hf_to_megds(self): assert self.is_refactored == False new_w = None - for pname, p in self.model.named_parameters(): + for pname, p in self.ds_model.named_parameters(): + if pname in [ f"{self.mega_emb_wnum}.word_embeddings.weight", f"{self.mega_lm_head_wnum}.lm_head.weight" @@ -253,6 +316,123 @@ def refactor(self): new_w = None self.is_refactored = True + + def _embedding_refactor_to_hf(self, pname, ds_w): + if pname == f"{self.mega_lm_head_wnum}.lm_head.weight": + hf_w = self.hf_model.lm_head.weight + hf_w_name = "lm_head.weight" + elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight": + hf_w = self.hf_model.model.embed_tokens.weight + hf_w_name = "model.embed_tokens.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank[:hf_w.shape[0], :]) + + def _direct_refactor_to_hf(self, pname, ds_w, hf_layer=None, subname=None): + if pname in [f"{self.mega_norm_wnum}.weight"]: + hf_w = self.hf_model.model.norm.weight + hf_w_name = "model.norm.weight" + elif subname in ["input_layernorm.weight"]: + hf_w = 
self.hf_model.model.layers[hf_layer].input_layernorm.weight + hf_w_name = f"model.layers.{hf_layer}.input_layernorm.weight" + elif subname in ["post_attention_layernorm.weight"]: + hf_w = self.hf_model.model.layers[hf_layer].post_attention_layernorm.weight + hf_w_name = f"model.layers.{hf_layer}.post_attention_layernorm.weight" + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w) + + def _attn_dense_refactor_to_hf(self, pname, ds_w, hf_layer, subname): + if subname == "self_attention.dense.weight": + hf_w = self.hf_model.model.layers[hf_layer].self_attn.o_proj.weight + hf_w_name = f"model.layers.{hf_layer}.self_attn.o_proj.weight" + elif subname == "mlp.dense_4h_to_h.weight": + hf_w = self.hf_model.model.layers[hf_layer].mlp.down_proj.weight + hf_w_name = f"model.layers.{hf_layer}.mlp.down_proj.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_last_dim(ds_w) + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank) + + def _mlphto4h_dense_refactor_to_hf(self, pname, ds_w, hf_layer): + hf_g_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" + hf_u_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + ds_w_shape = ds_w_all_rank.shape + ds_w_all_rank = ds_w_all_rank.reshape(self.tp_size, 2, -1, ds_w_shape[-1]) + self.hf_dict[hf_g_name] = copy.deepcopy(ds_w_all_rank[:, 0, :, :].reshape(-1, ds_w_shape[-1])) + self.hf_dict[hf_u_name] = copy.deepcopy(ds_w_all_rank[:, 1, :, :].reshape(-1, ds_w_shape[-1])) + + + def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer): + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + hf_q = self.hf_model.model.layers[hf_layer].self_attn.q_proj.weight + hf_k = self.hf_model.model.layers[hf_layer].self_attn.k_proj.weight + hf_v = self.hf_model.model.layers[hf_layer].self_attn.v_proj.weight + hf_q_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight" + hf_k_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight" + hf_v_name = f"model.layers.{hf_layer}.self_attn.v_proj.weight" + oldshape = hf_q.shape + hidden_size = oldshape[-1] + hidden_size_per_attention_head = divide(hidden_size, + self.config.num_attention_heads) + num_attention_heads_per_partition = divide(self.config.num_attention_heads, + self.tp_size) + newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size) + ds_w_out = ds_w_all_rank.reshape(*newshape) + self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1])) + self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1])) + self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1])) + + + def transform_from_megads_to_hf(self): + use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False + + for pname, p in self.ds_model.named_parameters(): + if pname in [ + f"{self.mega_emb_wnum}.word_embeddings.weight", + f"{self.mega_lm_head_wnum}.lm_head.weight", + ]: + self._embedding_refactor_to_hf(pname, p) + elif pname in [ + f"{self.mega_norm_wnum}.weight", + ]: + self._direct_refactor_to_hf(pname, p) + else: + mobj = self.decoder_pat.match(pname) + layer_num = int(mobj.group(1)) + subname = mobj.group(2) + hf_layer = layer_num - self.offset_num + if subname in ["self_attention.query_key_value.weight"]: + if not use_gqa: + self._qkv_refactor_to_hf(pname, p, hf_layer) + else: + #TODO(billishyahao): Not impl 
yet ... + assert False + elif subname in ["mlp.dense_h_to_4h.weight"]: + self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer) + elif subname in [ + "self_attention.dense.weight", + "mlp.dense_4h_to_h.weight" + ]: + self._attn_dense_refactor_to_hf(pname, p, hf_layer, subname) + elif subname in [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + ]: + self._direct_refactor_to_hf(pname, p, hf_layer, subname) + else: + print(f"Unrecognized weight type: {pname}") + raise ValueError(f"Unrecognized weight type: {pname}") + self.is_refactored = True + def record_mapping_info(self, record_msg): self.refactor_weight_list.append(record_msg) @@ -272,7 +452,18 @@ def inorder_show_record(self): torch.distributed.barrier() -def convert_hf_to_mega_ds(): +def load_hf_weights(args, no_init): + if args.load_mode == 'torchbin': + assert no_init == False, "only work with init" + return load_and_print_hf_weight(args.hf_ckpt_dir, args.hf_ckpt_num_shards) + elif args.load_mode == 'safetensor': + assert no_init == False, "only work with init" + return load_and_print_hf_weight_from_safetensor(args.hf_ckpt_dir, args.hf_ckpt_num_shards) + elif args.load_mode == 'auto': + return load_and_print_hf_weight_auto(args.hf_ckpt_dir, no_init) + + +def convert_ckpt(): """Build the model.""" args = get_args() print_rank_0(f'building model ...') @@ -286,49 +477,74 @@ def convert_hf_to_mega_ds(): enabled=args.zero_stage == 3, mpu=mpu): if args.deepspeed and not args.no_pipeline_parallel: - model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True) + ds_model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True) else: raise NotImplementedError("Not implemented") see_memory_usage(f"After Building Model", force=True) if torch.distributed.get_rank() < 2: - print(f"{torch.distributed.get_rank()} {model}") - - # load and initialize HF weight dict - # print hf weights list & mega-ds weights list - hf_ckpt_dir = args.origin_hf_ckpt_dir - hf_ckpt_num_of_shards = args.hf_ckpt_num_shards - loaded = load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards) - print_distinct_weights(model) - - # refactor weight from hf to mega-ds - - cur_refactor = refactor(model, loaded, args, config) - cur_refactor.refactor() - cur_refactor.inorder_show_record() + print(f"{torch.distributed.get_rank()} {ds_model}") - del loaded + # 'torchbin', 'safetensor', 'auto' + hf_model = load_hf_weights(args, no_init=args.to_hf_ckpt) - unwrapped_model = unwrap_model([model], (torchDDP, LocalDDP, Float16Module)) - optimizer = get_megatron_optimizer(unwrapped_model) - opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + # print_distinct_weights(hf_model) #init model and save print_rank_0(f"before deepspeed init") ds_engine, _, _, _ = deepspeed.initialize( - model=model, - optimizer=optimizer, + model=ds_model, + optimizer=None, args=args, - lr_scheduler=opt_param_scheduler, + lr_scheduler=None, mpu=mpu if args.no_pipeline_parallel else None) print_rank_0(f"after deepspeed init") - print_rank_0(f"mega-ds checkpoint will be saved in {args.save}") - save_checkpoint(0, [ds_engine], optimizer, opt_param_scheduler) - print_rank_0(f"save checkpoint completed") + if args.to_hf_ckpt: + load_checkpoint([ds_engine], None, None, load_only_weights=True) + print_rank_0(f"completed to load deepspeed actual checkpoint") + + # refactor weight from hf to mega-ds and vice versa + + cur_refactor = refactor(ds_model, hf_model, args, config) + if args.to_hf_ckpt: + cur_refactor.transform_from_megads_to_hf() + else: + 
cur_refactor.transform_from_hf_to_megds() + # cur_refactor.inorder_show_record() + + if args.to_hf_ckpt: + save_path = args.save + if not os.path.exists(save_path): + Path(save_path).mkdir(parents=True, exist_ok=True) + ckpt_per_pp_path = os.path.join(save_path, f"model_pp{mpu.get_pipeline_model_parallel_rank()}.pt") + torch.save(cur_refactor.hf_dict, ckpt_per_pp_path) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(f"hf checkpoint will be saved in {save_path}/release ") + if mpu.is_pipeline_last_stage(): + ## doing checkpoint merging and saving... + # hf_model.tie_weights() + + all_wei = {} + for pprank in range(mpu.get_pipeline_model_parallel_world_size()): + ckpt_per_pp_path = os.path.join(save_path, f"model_pp{pprank}.pt") + partial_wei = torch.load(ckpt_per_pp_path) + all_wei = all_wei | partial_wei + + hf_model.load_state_dict(all_wei) + + # mega-ds checkpoint will be saved in args.save + hf_model.save_pretrained(os.path.join(save_path, "release"), safe_serialization=True) + else: + print_rank_0(f"mega-ds checkpoint will be saved in {args.save}") + save_checkpoint(0, [ds_engine], None, None) + + print_rank_0(f"save checkpoint completed") if __name__ == "__main__": initialize_megatron(extra_args_provider=add_extra_args) - convert_hf_to_mega_ds() + convert_ckpt() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 399f93c10e..6e117db31a 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Processing large data for pretraining.""" @@ -193,10 +194,15 @@ def get_args(): group.add_argument('--tokenizer-type', type=str, required=True, choices=['BertWordPieceLowerCase','BertWordPieceCase', 'GPT2BPETokenizer', 'SentencePieceTokenizer', - 'GPTSentencePieceTokenizer', 'NullTokenizer'], + 'GPTSentencePieceTokenizer', 'HFTokenizer', + 'NullTokenizer'], help='What type of tokenizer to use.') group.add_argument('--tokenizer-model', type=str, default=None, help='YTTM tokenizer model.') + group.add_argument('--seq-length', type=int, default=None, + help='Maximum sequence length to process.') + group.add_argument('--trust-remote-code', action='store_true', + help='To run HFTokenizer model from local path.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') group.add_argument('--vocab-size', default=786, @@ -229,7 +235,7 @@ def get_args(): print("Are you sure you don't want to split sentences?") # some default/dummy values for the tokenizer - args.rank = 1 + args.rank = 0 args.make_vocab_size_divisible_by = 128 args.tensor_model_parallel_size = 1 args.vocab_extra_ids = 0 diff --git a/train_aGPT_7B.sh b/train_aGPT_7B.sh index 8bdf776ed9..286740fc89 100644 --- a/train_aGPT_7B.sh +++ b/train_aGPT_7B.sh @@ -7,18 +7,29 @@ # AuroraGPT-7B @ ALCF ##################################### - # 1. Navigate into `$PBS_O_WORKDIR` cd "${PBS_O_WORKDIR}" || exit HERE=$(python3 -c 'import os; print(os.getcwd())') && export HERE + # 2. source `ALCF/helpers.sh` source "${HERE}/ALCF/helpers.sh" || exit + # 3. call `setup` from `./ALCF/helpers.sh` setup "$@" || exit -export run_cmd="${run_cmd}" -echo "${run_cmd}" | tee -a "${OUTPUT_LOG}" -# 7. Tell user where to find output +# export run_cmd="${run_cmd}" +echo "${run_cmd[@]}" | tee -a "${OUTPUT_LOG}" + +# 4. Tell user where to find output printf "[!! 
%s] View output at:\n %s\n" "$(printBlue "NOTE")" "$(printYellow "${OUTPUT_LOG}")" | tee -a "${OUTPUT_LOG}" -XPU_IGNORE_STRING="CCL_WARN|\ -\ INFO\ \-\ |real_accelerator\.py|numexpr\.utils|async_io|libaio" -# 8. Evaluate ${run_cmd} and append outputs to ${OUTPUT_LOG} -eval "${run_cmd}" |& grep -E -v "${XPU_IGNORE_STRING}" |& tee -a "${OUTPUT_LOG}" + +# # 5. Ignore the following strings on Intel XPU devices +# # (otherwise they'll clutter up logs) +# XPU_IGNORE_STRING="CCL_WARN|\ -\ INFO\ \-\ |real_accelerator\.py|numexpr\.utils|async_io|libaio" + +# if [[ $(ezpz_get_machine_name) == "aurora" ]]; then +# module unload mpich && module load mpich +# fi +# +# 6. Evaluate ${run_cmd} and append outputs to ${OUTPUT_LOG} +# eval "${run_cmd[@]}" |& tee -a "${OUTPUT_LOG}" +eval "${run_cmd[*]}" |& tee -a "${OUTPUT_LOG}"
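When converting mega-ds back to an HF checkpoint, `convert_ckpt` above has each pipeline rank dump its slice of the state dict to `model_pp{rank}.pt`, and the last pipeline stage merges them with a dict union before calling `save_pretrained`. A standalone sketch of that merge step (paths and the pipeline world size are placeholders; it assumes the shards hold disjoint keys):

```python
import os
import torch

def merge_pp_shards(save_path, pp_world_size):
    """Union the per-pipeline-rank partial state dicts written as
    model_pp{rank}.pt into one full HF state dict, as convert_ckpt does
    on the last pipeline stage. Assumes the shards hold disjoint keys."""
    full_state = {}
    for pp_rank in range(pp_world_size):
        shard = torch.load(os.path.join(save_path, f"model_pp{pp_rank}.pt"),
                           map_location="cpu")
        full_state |= shard
    return full_state

# usage sketch (hf_model is an already-instantiated transformers model):
#   hf_model.load_state_dict(merge_pp_shards("/path/to/save", pp_world_size=4))
#   hf_model.save_pretrained("/path/to/save/release", safe_serialization=True)
```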