diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh index bc31c322ce..fac2f4d8ec 100644 --- a/ALCF/helpers.sh +++ b/ALCF/helpers.sh @@ -9,7 +9,7 @@ # ```bash # $ git clone https://github.com/argonne-lcf/Megatron-DeepSpeed # $ cd Megatron-DeepSpeed -# $ export PBS_O_WORKDIR=$(pwd) && source ALCF/helpers.sh && ezpz_setup +# $ export PBS_O_WORKDIR=$(pwd) && source ALCF/helpers.sh && setup # ``` # # and this will, automatically: @@ -120,14 +120,15 @@ setup() { # Create `deepspeed_config.json` from runtime params from ^ buildDSconfig || exit # Specify output directory for {logs, checkpoints, etc.} + setup_checkpoint || exit setOutput || exit # Specify additional `deepspeed` arguments (dependent on _newly created_ variables) set_args || exit # Ensure executable exists in expected path check_executable "${EXEC:-${WORKING_DIR}/pretrain_gpt_alcf.py}" - dfl="${DATA_FILE_LIST:-}" + dfl="${DATA_FILE_LIST:-"${PBS_O_WORKDIR}/ALCF/data-lists/$(get_machine_name)/dolma.txt"}" # Setup data + tokenizer via `DATA_FILE_LIST` and `TOKENIZER_TYPE` - tok="${TOKENIZER_TYPE:-Llama2}" + tok="${TOKENIZER_TYPE:-Llama2Tokenizer}" setup_tokenizer_and_data "${tok}" "${dfl}" || exit make_data || exit # Print job info @@ -150,7 +151,8 @@ setup_run_cmd() { # take in additional arguments # and append them directly to # the end of the `run_cmd` - custom_args="$@" + # custom_args="$@" + custom_args=("$@") ############################## #### Make it easy to track experiments by date ################### year="$(date "+%Y")" @@ -168,78 +170,113 @@ setup_run_cmd() { # `export LAUNCH_WITH=deepspeeed && bash train_llama_alcf.sh` ################################################################## setupLauncher "${LAUNCH_WITH:-MPICH}" || exit - TBDIR="${CKPT_DIR}/tensorboard" - mkdir -p "${TBDIR}" export data_cache_path="${CKPT_DIR}/${DATA_CACHE_PATH}" && mkdir -p "${data_cache_path}" printf "\n" echo "Using data_cache_path: ${data_cache_path}" - export DEFAULTS="\ - --split 100,0,0 \ - --log-interval 1 \ - --no-bias-gelu-fusion \ - --no-bias-dropout-fusion \ - --no-masked-softmax-fusion \ - --no-gradient-accumulation-fusion \ - --accumulate-allreduce-grads-in-fp32 \ - --log-timers-to-tensorboard \ - --log-optimizer-states-to-tensorboard" - OVERRIDE_CKPT_OPT_PARAM="${OVERRIDE_CKPT_OPT_PARAM:-}" - if [[ -z "${OVERRIDE_CKPT_OPT_PARAM:-}" ]]; then - DEFAULTS="${DEFAULTS} --use-checkpoint-opt_param-scheduler" - fi - if [[ "${SP}" -ge 2 ]]; then - export DEFAULTS="${DEFAULTS} --ds-sequence-parallel-size ${SP} --force-ds-sequence-parallel" - fi ################################################################## # WARN: to disable Llama-type architectures, toggle via: # `NO_LLAMA=1 bash train_llama_alcf.sh` ################################################################## if [[ -z "${NO_LLAMA:-}" ]]; then - llama_flags="${LLAMA_ARGS}\ - --num-key-value-heads ${NUM_KV_HEAD} \ - --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ - " - else - echo "!! Running in NO_LLAMA MODE !!" 
- llama_flags="" + llama_flags=( + "--swiglu" + "--hidden-dropout 0" + "--attention-dropout 0" + "--normalization rmsnorm" + "--disable-bias-linear" + "--no-query-key-layer-scaling" + "--use-rotary-position-embeddings" + "--untie-embeddings-and-output-weights" + "--num-key-value-heads ${NUM_KV_HEAD}" + "--ffn-hidden-size ${FFN_HIDDEN_SIZE}" + ) + fi + # min_lr=$(python3 -c 'print(f"{2 / (10 ** 5):.8f}")') + # "--min-lr ${LR:-${min_lr}}" # 2e-5 + # "--min-lr ${MIN_LR:-"2e-6"}" # 2e-5 + lr_flags=( + "--lr ${LR:-0.0002}" + "--lr-decay-style ${LR_DECAY_STYLE:-cosine}" + "--lr-warmup-fraction ${LR_WARMUP_FRAC:-0.05}" + ) + if [[ -n "${LR_DECAY_ITERS:-}" ]]; then + lr_flags+=("--lr-decay-iters ${LR_DECAY_ITERS:-}") + fi + + tb_flags=() + if [[ -z "${NO_TENSORBOARD:-}" ]]; then + TBDIR="${CKPT_DIR}/tensorboard" + mkdir -p "${TBDIR}" + tb_flags+=( + "--log-timers-to-tensorboard" + "--log-optimizer-states-to-tensorboard" + "--tensorboard-dir ${TBDIR}" + ) fi - export run_cmd=" - ${LAUNCHER} \ - --${DTYPE} \ - ${DEFAULTS} \ - --optimizer ${OPT} \ - --adam-beta1=${ADAM_BETA1} \ - --adam-beta2=${ADAM_BETA2} \ - --adam-eps=${ADAM_EPS} \ - --weight-decay=${WEIGHT_DECAY} \ - --save ${CKPT_DIR} \ - --load ${CKPT_DIR} \ - --seq-length ${SEQ} \ - --num-layers ${NLAYERS} \ - --hidden-size ${HIDDEN} \ - --tensorboard-dir ${TBDIR} \ - --train-iters ${TRAIN_ITERS} \ - --eval-iters ${EVAL_ITERS} \ - --distributed-backend ${BE} \ - --num-attention-heads ${HEADS} \ - --save-interval ${SAVE_INTERVAL} \ - --eval-interval ${EVAL_INTERVAL} \ - --max-position-embeddings ${SEQ} \ - --micro-batch-size ${MICRO_BATCH} \ - --tensor-model-parallel-size ${TP} \ - --global-batch-size ${GLOBAL_BATCH} \ - --pipeline-model-parallel-size ${PP} \ - --data-cache-path ${data_cache_path} \ - ${DATA_FLAGS} \ - ${LR_ARGS} \ - ${llama_flags} \ - ${FLASH_ARG} \ - ${TIMING_STR} \ - ${TOKENIZER_FLAGS} \ - ${ds_args} \ - ${gpt_args[*]} \ - ${custom_args} - " + dfl_fallback="${DATA_FILE_LIST:-${PBS_O_WORKDIR}/ALCF/data-lists/$(get_machine_name)/dolma.txt}" + + train_args=() + if [[ -z "${OVERRIDE_CKPT_OPT_PARAM:-}" ]]; then + train_args+=("--use-checkpoint-opt_param-scheduler") + fi + # "--init-method-std ${INIT_METHOD_STD:-0.0006}" + # "--shuffle-sample" + train_args+=( + "${lr_flags[@]}" + "${custom_args[@]}" + "${llama_flags[@]}" + "${DATA_FLAGS}" + "${FLASH_ARG}" + "${TIMING_STR}" + "${TOKENIZER_FLAGS}" + "${tb_flags[@]}" + "${ds_args[@]}" + "${gpt_args[@]}" + "--${DTYPE}" + "--shuffle-sample-in-corpus" + "--blend-sample-in-corpus" + "--accumulate-allreduce-grads-in-fp32" + "--no-bias-gelu-fusion" + "--no-bias-dropout-fusion" + "--no-masked-softmax-fusion" + "--no-gradient-accumulation-fusion" + "--optimizer=${OPT}" + "--tensor-model-parallel-size=${TP}" + "--pipeline-model-parallel-size=${PP}" + "--max-position-embeddings=${SEQ}" + "--micro-batch-size=${MICRO_BATCH}" + "--ds-sequence-parallel-size=${SP}" + "--global-batch-size=${GLOBAL_BATCH}" + "--split=${TRAIN_SPLIT:-990},${VAL_SPLIT:-10},${TEST_SPLIT:-0}" + "--timing-log-level=${TIMING_LOG_LEVEL:-1}" + "--eval-interval=${EVAL_INTERVAL:-100}" + "--eval-iters=${EVAL_ITERS:-20}" + "--save-interval=${SAVE_INTERVAL:-50}" + "--log-interval=${LOG_INTERVAL:-1}" + "--save=${SAVE:-${CKPT_DIR}}" + "--load=${LOAD:-${CKPT_DIR}}" + "--seq-length=${SEQ}" + "--num-layers=${NLAYERS}" + "--hidden-size=${HIDDEN}" + "--train-iters=${TRAIN_ITERS}" + "--distributed-backend=${BE}" + "--weight-decay=${WEIGHT_DECAY:-0.1}" + "--adam-beta1=${ADAM_BETA1:-0.9}" + "--adam-beta2=${ADAM_BETA2:-0.95}" + 
"--adam-eps=${ADAM_EPS:-0.00001}" + "--clip-grad=${CLIP_GRAD:-1.0}" + "--num-attention-heads=${HEADS}" + "--data-cache-path=${data_cache_path}" + "--data-file-list=${DATA_FILE_LIST:-${dfl_fallback}}" + ) + # "--adam-eps ${ADAM_EPS:-0.00001}" + cache_dir="${PBS_O_WORKDIR}/.cache/" + mkdir -p "${cache_dir}" + targs_cache="${cache_dir}/train_args.txt" + for arg in "${train_args[@]}"; do echo "${arg}" >>"${targs_cache}"; done + export TRAIN_ARGS=("$(printf '%s\n' "${train_args[@]}" | sort)") + printf "Training Arguments: %s\n" "${TRAIN_ARGS[@]}" + export run_cmd=("${LAUNCHER}" "${train_args[@]}") } save_dotenv() { @@ -383,17 +420,20 @@ setupLauncher() { printf " %s" "$(printMagenta "${LAUNCHER}")" } -set_lr_args() { - LR_ARGS="--lr ${LR} --lr-decay-style cosine" - if [[ -n "${LR_DECAY_ITERS:-}" ]]; then - LR_ARGS="${LR_ARGS} --lr-decay-iters ${LR_DECAY_ITERS}" - fi - if [[ -n "${LR_WARMUP_FRAC}" ]]; then - LR_ARGS="${LR_ARGS} --lr-warmup-fraction ${LR_WARMUP_FRAC}" - fi - echo "LR_ARGS: ${LR_ARGS}" - export LR_ARGS="${LR_ARGS}" -} +# set_lr_args() { +# export LR=${LR:-0.0002} # LEARNING_RATE +# export LR_WARMUP_FRAC=${LR_WARMUP_FRAC:-0.05} # LEARNING RATE WARMUP +# export LR_DECAY_ITERS=${LR_DECAY_ITERS:-} # LR DECAY ITERS +# LR_ARGS="--lr ${LR} --lr-decay-style cosine" +# if [[ -n "${LR_DECAY_ITERS:-}" ]]; then +# LR_ARGS="${LR_ARGS} --lr-decay-iters ${LR_DECAY_ITERS}" +# fi +# if [[ -n "${LR_WARMUP_FRAC}" ]]; then +# LR_ARGS="${LR_ARGS} --lr-warmup-fraction ${LR_WARMUP_FRAC}" +# fi +# echo "LR_ARGS: ${LR_ARGS}" +# export LR_ARGS="${LR_ARGS}" +# } ######################################################################### # `get_batch_size_on_polaris`: Identify MICRO_BATCH to use on Polaris. @@ -448,12 +488,14 @@ _get_num_hosts_from_hostfile() { # # [2 tiles] x [6 xpus / tile] = 12 xpus # -# | nnhosts | nhosts | GAS | -# |:-------------:|:---------:|:-----:| -# | 64 <= n < inf | [64, inf) | 1 | -# | 32 <= n < 64 | [32, 64) | 2 | -# | 16 <= n < 32 | [16, 32) | 4 | -# | 0 <= n < 16 | [0, 16) | 8 | +# | nnhosts | nhosts | GAS | +# |:---------------:|:----------:|:-----:| +# | 256 <= n < inf | [256, inf) | 1 | +# | 128 <= n < 256 | [128, 256) | 2 | +# | 32 <= n < 128 | [32, 128) | 4 | +# | 16 <= n < 32 | [16, 32) | 8 | +# | 0 <= n < 16 | [0, 16) | 16 | +# ########################################### get_grad_acc_steps_on_aurora() { if [[ "$#" == 0 ]]; then @@ -461,18 +503,21 @@ get_grad_acc_steps_on_aurora() { elif [[ "$#" == 1 ]]; then hf="$1" else + echo "Usage: get_grad_acc_steps_on_aurora" echo "Expected exactly 0 or 1 arguments, received: $#" exit 1 fi nhosts=$(wc -l <"${hf}") - if [[ 64 -le "${nhosts}" ]]; then + if [[ "${nhosts}" -gt 256 ]]; then gas=1 - elif [[ 32 -le "${nhosts}" && "${nhosts}" -lt 64 ]]; then + elif [[ 128 -le "${nhosts}" && "${nhosts}" -lt 256 ]]; then gas=2 - elif [[ 16 -le "${nhosts}" && "${nhosts}" -lt 32 ]]; then + elif [[ 32 -le "${nhosts}" && "${nhosts}" -lt 128 ]]; then gas=4 - else + elif [[ 16 -le "${nhosts}" && "${nhosts}" -lt 32 ]]; then gas=8 + else + gas=16 fi echo "${gas}" } @@ -518,14 +563,13 @@ set_ccl_vars_on_aurora() { ############################################################################## setParams() { FLASH_ARG="" - LLAMA_ARGS="--attention-dropout 0 --hidden-dropout 0" # ---- [Parallelism Settings] -------------------------------------------+ # ------ [Aurora] -------||------ [SunSpot] ------------- # if [[ $(hostname) == x4* || $(hostname) == x1* ]]; then mn=$(get_machine_name) if [[ "${mn}" == "aurora" || "${mn}" == "sunspot" ]]; then 
TP=${TP:-1} # TP = 1 - export SAVE_INTERVAL="${SAVE_INTERVAL:-20}" + export SAVE_INTERVAL="${SAVE_INTERVAL:-50}" export CCL=${CCL:-ccl} # CCL export BE="${CCL}" # COMMUNICATION BACKEND = CCL export DTYPE=${DTYPE:-bf16} # DTYPE: bf16 @@ -534,7 +578,7 @@ setParams() { export GRAD_ACC_STEPS="${GRAD_ACC_STEPS:-${gas}}" # export GRAD_ACC_STEPS="${GRAD_ACC_STEPS:-$(get_grad_acc_steps_on_aurora "$@)}" echo "[setParams] Using GRAD_ACC_STEPS: ${GRAD_ACC_STEPS}" - MICRO_BATCH=${MICRO_BATCH:-4} # MICRO_BATCH = 4 + MICRO_BATCH=${MICRO_BATCH:-1} # MICRO_BATCH = 4 #### [sam: 08/17/2024] ########################################## # Use best set of CCL env vars from Gordon Bell runs on Aurora set_ccl_vars_on_aurora @@ -558,9 +602,7 @@ setParams() { echo "Using flash-attn !!" FLASH_ARG="--use-flash-attn-builder" fi - ###################################################################### - # +--------[Polaris]-----------------------------------+ - # elif [[ $(hostname) == x3* ]]; then + # [Polaris] elif [[ "${mn}" == "polaris" || "${mn}" == "sirius" ]]; then # export LAUNCH_CMD="${LAUNCH_CMD:-deepspeed}" TP=${TP:-1} # TP = 2 @@ -579,30 +621,25 @@ setParams() { fi echo "Setting up AWS NCCL OFI Plugin on Polaris..." source "${WORKING_DIR}/ALCF/aws_ofi_nccl_plugin.sh" || exit - # +--------[Perlmutter]---------------------------------+ - # elif [[ $(hostname) == login* || $(hostname) == nid* ]]; then + # [Perlmutter] elif [[ "${mn}" == login* || "${mn}" == nid* ]]; then TP="${TP:-2}" export NCCL="${NCCL:-nccl}" export BE="${NCCL}" export DTYPE="${DTYPE:-bf16}" - MICRO_BATCH="${MICRO_BATCH:-8}" + MICRO_BATCH="${MICRO_BATCH:-1}" if [[ -n "${NO_FLASH_ATTN-}" ]]; then echo "Not using flash-attn!!" else FLASH_ARG="--use-flash-attn-v2" fi fi - # +----------------------------------------------------------------------+ export TP="${TP}" export PP="${PP:-1}" export SP="${SP:-1}" export FLASH_ARG="${FLASH_ARG}" export DTYPE="${DTYPE:-bf16}" export OPT="${OPT:-adamw}" - export ADAM_BETA1="${ADAM_BETA1:-0.9}" - export ADAM_BETA2="${ADAM_BETA2:-0.95}" - export ADAM_EPS="${ADAM_EPS:-0.00001}" # 1 * 10^{-5} export WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" export HOSTFILE="${HOSTFILE:-${PBS_NODEFILE}}" NHOSTS=$(wc -l <"${HOSTFILE}") @@ -621,18 +658,19 @@ setParams() { # +---[Run Settings]------------------------------------------------------+ export SEQ=${SEQ:-4096} # SEQ_LEN: 4096 export ZERO_STAGE=${ZERO_STAGE:-1} # ZERO OFFLOADING STAGE - export MICRO_BATCH=${MICRO_BATCH:-8} # MICRO BATCH SIZE + export MICRO_BATCH=${MICRO_BATCH:-1} # MICRO BATCH SIZE export GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} # GRADIENT ACCUMULATION STEPS - export EVAL_ITERS="${EVAL_ITERS:-10}" # NUMBER OF EVAL ITERS TO RUN - export EVAL_INTERVAL="${EVAL_INTERVAL:-50000}" # HOW FREQUENTLY TO RUN EVAL - export SAVE_INTERVAL=${SAVE_INTERVAL:-50} # HOW FREQUENTLY TO SAVE CKPTS export TIMING_LOG_LEVEL="${TIMING_LOG_LEVEL:-1}" # TIMING VERBOSITY IN LOGS export ACT_CKPT_NUM_LAYERS="${ACT_CKPT_NUM_LAYERS:-1}" # NUM LAYERS TO CHECKPOINT ACTIVATIONS - export USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING:-1} # USE ACTIVATION CHECKPOINTING ? + export USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING:-} # USE ACTIVATION CHECKPOINTING ? 
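These per-machine defaults feed the hunk that follows, where `GLOBAL_BATCH_MAX` is computed from the world size and the parallelism degrees and a token budget is converted into `TRAIN_ITERS`. A rough Python sketch of that arithmetic, with made-up example values (illustrative only):

```python
# Made-up example values; the script reads these from the job environment.
world_size = 288                    # e.g. 24 Aurora nodes x 12 XPU tiles
micro_batch = 1                     # new MICRO_BATCH default in this patch
grad_acc_steps = 8                  # 24 nodes sits in the [16, 32) bucket
tp = pp = sp = 1                    # parallelism degrees
seq = 4096
train_tokens = 2_000_000_000_000    # the 2T-token fallback budget

# GLOBAL_BATCH_MAX=$(( WORLD_SIZE * MICRO_BATCH * GRAD_ACC_STEPS / TP / PP / SP ))
global_batch_max = world_size * micro_batch * grad_acc_steps // (tp * pp * sp)

# TRAIN_ITERS=$(( TRAIN_TOKENS / SEQ / GLOBAL_BATCH ))
train_iters = train_tokens // seq // global_batch_max

print(f"{global_batch_max=} {train_iters=}")
```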
export GLOBAL_BATCH_MAX=$((WORLD_SIZE * MICRO_BATCH * GRAD_ACC_STEPS / TP / PP / SP)) # MAX GLOBAL BATCH SIZE export GLOBAL_BATCH="${GLOBAL_BATCH:-${GLOBAL_BATCH_MAX}}" # WILL USE MAX IF NOT SET IN ENVIRONMENT - # export TRAIN_ITER=${TRAIN_ITER:-317892} # NUMBER OF TRAIN ITERS - if [[ -z "${TRAIN_ITERS:-${TRAIN_ITER:-}}" ]]; then + if [[ -n "${TRAIN_TOKENS:-}" ]]; then + export TRAIN_TOKENS="${TRAIN_TOKENS}" + export TRAIN_ITERS=$((TRAIN_TOKENS / SEQ / GLOBAL_BATCH)) + printf "TRAIN_TOKENS=%s (=%sB tokens)\n" "${TRAIN_TOKENS}" "$((TRAIN_TOKENS / 10 ** 9))" + printf "TRAIN_ITERS=%s\n" "${TRAIN_ITERS}" + elif [[ -z "${TRAIN_ITERS:-${TRAIN_ITER:-}}" ]]; then export TRAIN_TOKENS=${TRAIN_TOKENS:-2000000000000} export TRAIN_ITERS=$((TRAIN_TOKENS / SEQ / GLOBAL_BATCH)) printf "TRAIN_TOKENS=%s (=%sB tokens)\n" "${TRAIN_TOKENS}" "$((TRAIN_TOKENS / 10 ** 9))" @@ -648,27 +686,20 @@ setParams() { # # For this reason, we only use the default LLAMA_ARGS when SP=0. ########################################################################## - if [[ "${SP}" == 1 ]]; then - export LLAMA_ARGS="${LLAMA_ARGS} --no-query-key-layer-scaling --use-rotary-position-embeddings --untie-embeddings-and-output-weights --swiglu --normalization rmsnorm --disable-bias-linear" - else - export LLAMA_ARGS="" - echo "NOT USING ROTARY EMBEDDINGS! LLAMA_ARGS=${LLAMA_ARGS}" - fi + # # -----[Learning Rate Settings]-------------------------------------------- + # export LR=${LR:-0.0002} # LEARNING_RATE + # export LR_WARMUP_FRAC=${LR_WARMUP_FRAC:-0.05} # LEARNING RATE WARMUP + # export LR_DECAY_ITERS=${LR_DECAY_ITERS:-} # LR DECAY ITERS + # set_lr_args # -----[Learning Rate Settings]-------------------------------------------- - export LR=${LR:-0.0003} # LEARNING_RATE - export LR_WARMUP_FRAC=${LR_WARMUP_FRAC:-0.05} # LEARNING RATE WARMUP - export LR_DECAY_ITERS=${LR_DECAY_ITERS:-} # LR DECAY ITERS - set_lr_args - # -----[Learning Rate Settings]-------------------------------------------- - if [[ "${TIMING_LOG_LEVEL}" -ge 1 ]]; then - TIMING_STR="\ - --timing-log-level ${TIMING_LOG_LEVEL} \ - --log-timers-to-tensorboard \ - --log-optimizer-states-to-tensorboard \ - " - else - TIMING_STR="" - fi + # # if [[ "${TIMING_LOG_LEVEL:-1}" -gt 1 ]]; then + # if [[ "${TIMING_LOG_LEVEL:-1}" -gt 1 ]]; then + # TIMING_STR="\ + # --timing-log-level ${TIMING_LOG_LEVEL}" + # # " + # else + # TIMING_STR="" + # fi } ############################################## @@ -679,19 +710,31 @@ setParams() { ############################################## set_args() { # ---- Set DeepSpeed arguments -------------------------------- - ds_args=" " - ds_args=" --deepspeed ${ds_args}" - if [[ $PP == 1 ]]; then - ds_args=" --no-pipeline-parallel ${ds_args}" + ds_args=( + "--deepspeed" + ) + if [[ "${PP:-1}" == 1 ]]; then + ds_args+=("--no-pipeline-parallel") fi - ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" - ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}" + ds_args+=("--deepspeed_config=${DS_CONFIG}") + ds_args+=("--zero-stage=$ZERO_STAGE") if [[ "${ZERO_STAGE}" == 3 ]]; then - ds_args="--use-mics ${ds_args}" + ds_args+=("--use-mics") fi - if [[ "$USE_ACTIVATION_CHECKPOINTING" == 1 ]]; then + # ds_args=" " + # ds_args=" --deepspeed ${ds_args}" + # if [[ $PP == 1 ]]; then + # ds_args=" --no-pipeline-parallel ${ds_args}" + # fi + # ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" + # ds_args="--zero-stage=$ZERO_STAGE ${ds_args}" + # if [[ "${ZERO_STAGE}" == 3 ]]; then + # ds_args="--use-mics ${ds_args}" + # fi + if [[ -n "${USE_ACTIVATION_CHECKPOINTING:-}" ]]; 
then echo "!! Caught USE_ACTIVATION_CHECKPOINTING=${USE_ACTIVATION_CHECKPOINTING} !!" - ds_args=" --deepspeed-activation-checkpointing ${ds_args}" + ds_args+=("--deepspeed-activation-checkpointing") + # ds_args=" --deepspeed-activation-checkpointing ${ds_args}" # --checkpoint-activations \ # --deepspeed-activation-checkpointing fi @@ -804,7 +847,8 @@ get_output_prefix() { pre="${pre}_sp${SP}_pp${PP}_tp${TP}_${DTYPE}_opt${OPT}" pre="${pre}_lr${LR}_lwf${LR_WARMUP_FRAC}" if [[ -n "${TOKENIZER_TYPE:-}" ]]; then - pre="${pre}_tok${TOKENIZER_TYPE}" + _tok=$(echo "${TOKENIZER_TYPE}" | sed 's/Tokenizer//g') # noqa + pre="${pre}_tok${_tok}" fi if [[ -n "${LR_DECAY_ITERS}" ]]; then pre="${pre}_ldi${LR_DECAY_ITERS}" @@ -822,9 +866,21 @@ setOutput() { OUTPUT_DIR="logs/${OUTPUT_PREFIX}/$(date +%Y%m%d-%H%M%S)_${WORLD_SIZE}_${HOSTNAME}" export OUTPUT_DIR="${OUTPUT_DIR}" && mkdir -p "${OUTPUT_DIR}" export OUTPUT_LOG="${OUTPUT_DIR}/output.log" - export CKPT_DIR="checkpoints/${OUTPUT_PREFIX}" echo "${OUTPUT_LOG}" >>"logs/latest" printf "\n Please see logs at: %s\n" "$(printGreen "${OUTPUT_DIR}")" +} + +get_checkpoint_dir() { + if [[ -n "${CKPT_DIR:-}" ]]; then + echo "${CKPT_DIR}" + else + echo "checkpoints/$(get_output_prefix)" + fi +} + +setup_checkpoint() { + ckpt_dir=$(get_checkpoint_dir) + export CKPT_DIR="${ckpt_dir}" printf "Checkpoints will be saved to: %s\n" "$(printYellow "${CKPT_DIR}")" } @@ -832,12 +888,13 @@ setOutput() { # Build DeepSpeed config and write to .json ############################################# buildDSconfig() { - export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}" + # export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}" export DS_CONFIG="${WORKING_DIR}/ds-configs/ds_stage${ZERO_STAGE}_mb${MICRO_BATCH}_gb${GLOBAL_BATCH}_pp${PP}_${DTYPE}.json" mkdir -p "$(dirname "${DS_CONFIG}")" echo "DS_CONFIG: ${DS_CONFIG}" printf "ZS: %s, MB: %s, GB: %s, PP: %s, DTYPE: %s" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}" generateDSconfig "${DS_CONFIG}" + cat "${DS_CONFIG}" | jq . } ############################################################################### @@ -893,31 +950,6 @@ install_dependencies() { fi } -###################################################################### -# install_deepspeed_for_xpu -# -# Install microsoft/DeepSpeed on PVC -# -# This will: -# 1. Clone rep -# 2. Checkout appropriate branch -# 3. Install into virtual environment -###################################################################### -install_deepspeed_for_xpu() { - # python3 -m pip install "torch==2.1.0.post2" torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 oneccl_bind_pt==2.1.300+xpu --extra-index-url "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" - echo "Building + Installing DeepSpeed on $(hostname)" - outdir="${WORKING_DIR}/deps/DeepSpeed" - mkdir -p "${outdir}" - git clone https://github.com/microsoft/DeepSpeed.git "${outdir}" - cd "${outdir}" || exit - echo "[install_deepspeed_for_xpu] !! pwd: $(pwd)" - python3 -m pip install --require-virtualenv -r requirements/requirements.txt 1>/dev/null - python3 -m pip install xgboost "numpy<2" --force-reinstall --upgrade --require-virtualenv 1>/dev/null - python setup.py develop 1>/dev/null - cd "${WORKING_DIR}" - echo "[install_deepspeed_for_xpu] !! 
pwd: $(pwd)" -} - ################################################# # Fix for distributed key value store on Aurora ################################################# @@ -1003,9 +1035,11 @@ setup_tokenizer_and_data() { fi echo "Setting up tokenizer with ${tok}" echo "Using data_file_list: ${dfl}" + _data_flags=() + _tokenizer_flags=() if [[ ${tok} == gpt* || ${tok} == GPT* ]]; then export TOKENIZER_TYPE="GPT2" - export TOKENIZER_FLAGS="--tokenizer-type GPT2BPETokenizer" + _tokenizer_flags+=("--tokenizer-type GPT2BPETokenizer") machine=$(get_machine_name) if [[ ${machine} == "polaris" ]]; then export DATA_PARENT="${DATA_PARENT:-/eagle/argonne_tpc/foremans/projects/argonne-lcf/Megatron-DeepSpeed/dataset}" @@ -1019,18 +1053,25 @@ setup_tokenizer_and_data() { export VOCAB_FILE="${DATA_PARENT}/gpt2-vocab.json" export MERGE_FILE="${DATA_PARENT}/gpt2-merges.txt" export DATA_PATH="${DATA_PARENT}/BookCorpusDataset_text_document" - export DATA_FLAGS="--data-path ${DATA_PATH} --vocab-file ${VOCAB_FILE} --merge-file ${MERGE_FILE}" + _data_flags+=( + "--data-path ${DATA_PATH}" + "--vocab-file ${VOCAB_FILE}" + "--merge-file ${MERGE_FILE}" + ) else - export DATA_FLAGS="" - export TOKENIZER_TYPE="Llama2" + export TOKENIZER_TYPE="${TOKENIZER_TYPE:-Llama2Tokenizer}" tm="${WORKING_DIR}/ALCF/tokenizer.model" # fallback: Megatron-DeepSpeed/ALCF/tokenizer.model export TOKENIZER_MODEL="${TOKENIZER_MODEL:-${tm}}" # USE TOKENIZER_MODEL from env, else fallback from ^ - export TOKENIZER_FLAGS="--tokenizer-type Llama2Tokenizer --tokenizer-model ${TOKENIZER_MODEL}" - if [[ "${TOKENIZER_TYPE}" != "GPT2" ]]; then - echo "Using tokenizer: ${TOKENIZER_TYPE}. Setting up data with ${DATA_FILE_LIST-}" - setData "${dfl}" || exit - fi + _tokenizer_flags+=( + "--tokenizer-type ${TOKENIZER_TYPE}" + "--tokenizer-model ${TOKENIZER_MODEL}" + ) + # if [[ "${TOKENIZER_TYPE}" != "GPT2" ]]; then + echo "Using tokenizer: ${TOKENIZER_TYPE}. Setting up data with ${DATA_FILE_LIST:-}" + setData "${dfl}" || exit fi + export DATA_FLAGS="${_data_flags[*]}" + export TOKENIZER_FLAGS="${_tokenizer_flags[*]}" printf "[setData] DATA_FLAGS: %s\n" "$(printGreen "${DATA_FLAGS}")" printf "[setData] TOKENIZER_FLAGS: %s\n" "$(printMagenta "${TOKENIZER_FLAGS}")" } @@ -1059,7 +1100,7 @@ setData() { # ------------------------[dfl: abbrv. for DATA_FILE_LIST] export WEIGHT_SUM="${ws}" export DFL_STEM="${dfl_stem}" export DATA_CACHE_PATH="${dcp}" - export DATA_FLAGS="${DATA_FLAGS} --data-file-list ${DATA_FILE_LIST}" # --data-cache-path ${DATA_CACHE_PATH}" + # export DATA_FLAGS="${DATA_FLAGS} --data-file-list ${DATA_FILE_LIST}" # --data-cache-path ${DATA_CACHE_PATH}" echo "--------------------" echo "Updated environment:" printf "DATA_FILE_LIST: %s\n" "${DATA_FILE_LIST}" @@ -1071,6 +1112,30 @@ setData() { # ------------------------[dfl: abbrv. 
for DATA_FILE_LIST] echo "--------------------" } +generateDSconfig_new() { + cat <"${CONFIG_JSON}" + { + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "bf16": { + "enabled": true + }, + + "data_types": { + "grad_accum_dtype": "fp32" + }, + + "wall_clock_breakdown" : false + } +EOT +} + ################################################################################ # generateDSconfig # @@ -1089,16 +1154,6 @@ generateDSconfig() { exit 1 fi done - # \"optimizer\": { - # \"type\": \"AdamW\", - # \"params\": { - # \"lr\": ${LR}, - # \"beta1\": 0.9, - # \"beta2\": 0.95, - # \"eps\": 1e-5, - # \"weight_decay\": 1e-1 - # } - # }, # \"scheduler\": { # \"type\": \"WarmupLR\", # \"params\": { @@ -1111,27 +1166,22 @@ generateDSconfig() { common="\ \"train_batch_size\": $GLOBAL_BATCH, \"train_micro_batch_size_per_gpu\": $MICRO_BATCH, + \"gradient_clipping\": 1.0, \"steps_per_print\": 1, \"gradient_accumulation_steps\": $GRAD_ACC_STEPS, + \"zero_force_ds_cpu_optimizer\": false, \"zero_allow_untested_optimizer\": true, - \"gradient_clipping\": 1.0, - \"activation_checkpointing\": { - \"partition_activations\": true, - \"contiguous_memory_optimization\": true - }, \"wall_clock_breakdown\": false," - flops_profiler="\ - \"flops_profiler\": { - \"enabled\": true, - \"profile_step\": 2, - \"module_depth\": -1, - \"top_modules\": 1, - \"detailed\": true, - \"output_file\": null - }" + # if [[ "${USE_ACTIVATION_CHECKPOINTING}" == 1 ]]; then + # activation_checkpointing="\ + # \"activation_checkpointing\": { + # \"partition_activations\": true, + # \"contiguous_memory_optimization\": true + # }," + # fi if [[ $DTYPE == "bf16" ]]; then + # \"communication_data_type\": \"bf16\", dtype="\ - \"communication_data_type\": \"bf16\", \"fp16\": { \"enabled\": false, \"loss_scale\": 0, @@ -1160,6 +1210,38 @@ generateDSconfig() { else dtype="\"communication_data_type\": \"fp32\"," fi + if [[ "${OPT:-}" == "ds.adamw" ]]; then + optimizer="\ + \"optimizer\": { + \"type\": \"AdamW\", + \"params\": { + \"lr\": ${LR}, + \"beta1\": ${ADAM_BETA1}, + \"beta2\": ${ADAM_BETA2}, + \"eps\": ${ADAM_EPS}, + \"weight_decay\": 1e-1 + }, + }," + elif [[ "${OPT:-}" == "ds.onebitlamb" ]]; then + optimizer="\ + \"optimizer\": { + \"type\": \"OneBitLamb\", + \"params\": { + \"lr\": 11e-3, + \"max_coeff\": 0.3, + \"min_coeff\": 0.01, + \"freeze_step\": 1000, + \"cuda_aware\": false, + \"comm_backend_name\": \"${BE}\", + \"coeff_beta\": 0.9, + \"factor_max\": 4.0, + \"factor_min\": 0.5, + \"factor_threshold\": 0.1 + } + }," + else + optimizer="" + fi if [[ "${ZERO_STAGE}" == 3 ]]; then # \"mics_shard_size\": 2, zero="\ @@ -1185,8 +1267,7 @@ generateDSconfig() { }," # elif [[ $ZERO_STAGE == 2 ]]; then elif [[ "${ZERO_STAGE}" == 2 || "${ZERO_STAGE}" == 1 ]]; then - # if [[ -n "${CPU_OPTIMIZER}" ]]; then - if [[ "${CPU_OPTIMIZER:-0}" != 0 ]]; then + if [[ -n "${CPU_OPTIMIZER:-}" ]]; then echo "!!!! CAUGHT CPU_OPTIMIZER !!!!" zero="\ \"zero_optimization\": { @@ -1215,18 +1296,27 @@ generateDSconfig() { else extra="\ \"comms_logger\": { - \"enabled\": true, + \"enabled\": ${COMMS_LOGGER:-false}, \"verbose\": false, - \"prof_all\": true, \"debug\": false }," fi else echo 'Please add the correct config set!!!' 
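`generateDSconfig` stitches the final JSON together from the `common`, `optimizer`, `zero`, `dtype`, and `extra` fragments above. As a compact illustration, here is a Python sketch that assembles an equivalent ZeRO-1/bf16 config using the keys shown in `generateDSconfig_new` (example values only; the real script emits additional fields such as the comms logger and flops profiler):

```python
import json

# Example values; the shell functions interpolate these from the environment.
global_batch, micro_batch, grad_acc_steps, zero_stage = 2304, 1, 8, 1

ds_config = {
    "train_batch_size": global_batch,
    "train_micro_batch_size_per_gpu": micro_batch,
    "gradient_accumulation_steps": grad_acc_steps,
    "gradient_clipping": 1.0,
    "steps_per_print": 1,
    "zero_force_ds_cpu_optimizer": False,
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {"stage": zero_stage},
    "bf16": {"enabled": True},
    "data_types": {"grad_accum_dtype": "fp32"},
    "wall_clock_breakdown": False,
}

# Filename follows the DS_CONFIG naming pattern used by buildDSconfig.
with open("ds_stage1_mb1_gb2304_pp1_bf16.json", "w") as fd:
    json.dump(ds_config, fd, indent=2)
```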
fi + flops_profiler="\ + \"flops_profiler\": { + \"enabled\": true, + \"profile_step\": 2, + \"module_depth\": -1, + \"top_modules\": 1, + \"detailed\": true, + \"output_file\": null + }" cat <"$1" { $common +$optimizer $zero $dtype $extra @@ -1304,6 +1394,87 @@ printWhite() { printf "\e[1;37m%s\e[0m\n" "$@" } +reset_env() { + custom_vars=( + NO_FLASH_ATTN + USE_FLASH_ATTN + TP + PP + SP + FLASH_ARG + OPT + ADAM_BETA1 + ADAM_BETA2 + ADAM_EPS + WEIGHT_DECAY + HEADS + NLAYERS + HIDDEN + NUM_KV_HEAD + FFN_HIDDEN_SIZE + SEQ + ZERO_STAGE + MICRO_BATCH + EVAL_ITERS + EVAL_INTERVAL + TIMING_LOG_LEVEL + ACT_CKPT_NUM_LAYERS + USE_ACTIVATION_CHECKPOINTING + GLOBAL_BATCH_MAX + GLOBAL_BATCH + TRAIN_TOKENS + TRAIN_ITERS + MODEL_TYPE + LR + LR_WARMUP_FRAC + LR_DECAY_ITERS + LR_ARGS + CPU_OPTIMIZER + DS_CONFIG + OUTPUT_DIR + OUTPUT_LOG + CKPT_DIR + ds_args + EXEC + EXEC_STEM + DATA_FLAGS + TOKENIZER_TYPE + TOKENIZER_MODEL + TOKENIZER_FLAGS + DATA_FILE_LIST + NUM_DOCS + WEIGHT_SUM + DFL_STEM + DATA_CACHE_PATH + DOTENV_FILE + YEAR + MONTH + DAY + TODAY + STARTED_AT + LAUNCHER + data_cache_path + DEFAULTS + ) + # LLAMA_ARGS + printf "Unsetting custom vars: %s\n" "${custom_vars[*]}" + unset "${custom_vars[@]}" +} + +convert_ckpt_to_universal() { + if [[ "$#" -ne 1 ]]; then + echo "Usage: convert_ckpt_to_universal ckpt_dir" + echo "Expected one argument (ckpt_dir), received: $#" + exit 1 + fi + ckptdir=$1 + gs=$(cat "${ckptdir}/latest_checkpointed_iteration.txt") + src="${ckptdir}/global_step${gs}" + dst="${ckptdir}/global_step${gs}_universal" + convert_script="${PBS_O_WORKDIR}/deps/DeepSpeed/checkpoint/ds_to_universal.py" + python3 "${convert_script}" --input_folder "${src}" --output_folder "${dst}" +} + ########################### # call helpers_main() ########################### diff --git a/ALCF/requirements/requirements.txt b/ALCF/requirements/requirements.txt index e0a969358d..03541ba514 100644 --- a/ALCF/requirements/requirements.txt +++ b/ALCF/requirements/requirements.txt @@ -15,3 +15,4 @@ six numpy<2 schedulefree packaging>=20.0 +wandb diff --git a/ALCF/test_blendable_dataset.py b/ALCF/test_blendable_dataset.py index a3cabddd29..c119862142 100644 --- a/ALCF/test_blendable_dataset.py +++ b/ALCF/test_blendable_dataset.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import time +import json start_time = time.time() from mpi4py import MPI import os @@ -37,7 +38,7 @@ def print_rank_0(msg): os.makedirs(args.trace_dir, exist_ok=True) - +corpus_all = [] data_file_list = args.data_file_list print_rank_0(f"Reading data from {args.data_file_list}") files = [] @@ -51,6 +52,9 @@ def print_rank_0(msg): files.append(float(w)) files.append(fname) files.append(c) + if c not in corpus_all: + corpus_all.append(c) + splits_string="100,0,0" weights = np.array(weights) @@ -82,6 +86,40 @@ def print_rank_0(msg): print_rank_0(f"Total number of samples: {len(train_ds)}") print_rank_0(f"Weights set: {weights[:min(8, num_datasets)]}") + +def get_sample_info(blendable_dataset, idx): + # corpus dataset + cd = blendable_dataset.dataset_index[idx] + # index within the corpus dataset + cds = blendable_dataset.dataset_sample_index[idx] + # dataset index within each corpus + fcd = blendable_dataset.datasets[cd].dataset_index[cds] + # sample index within the dataset + fcds = blendable_dataset.datasets[cd].dataset_sample_index[cds] + # corresponding data file + prefix = blendable_dataset.datasets[cd].dataset_builders[fcd].prefix + corpus = blendable_dataset.datasets[cd].dataset_builders[fcd].corpus + #v = blendable_dataset[idx]['text'] + 
#norm = np.linalg.norm(v) + return prefix, corpus, fcds + +num_batches = args.train_iters +print(f"global_batch_size: {args.global_batch_size}") +print(f"number of batches: {num_batches}") + +fout = open("samples_list.jsonl", "w") +if comm.rank == 0: + for i in range(num_batches): + ns_corpus = {} + for c in corpus_all: + ns_corpus[c] = 0 + for j in range(args.global_batch_size): + prefix, corpus, idx = get_sample_info(train_ds, i*args.global_batch_size+j) + ns_corpus[corpus] +=1 + fout.write(f"\u007b 'batch': {i}, 'sample': {j}, 'corpus': '{corpus}', 'prefix': '{prefix}', 'dataset_sample_index': {idx} \u007d\n") + fout.write(f"\u007b 'batch': {i}, 'histogram': {ns_corpus} \u007d \n") +comm.Barrier() +exit() start_build_dataloader = time.time() print_rank_0(f"Starting to build the data loader") rank_in_parallel_group = mpu.get_sequence_parallel_rank() diff --git a/megatron/arguments.py b/megatron/arguments.py index 2f52084329..9b0e6ccb1a 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1291,6 +1291,10 @@ def _add_data_args(parser): group.add_argument('--data-file-list', type=str, default=None, help='The file with the list of dataset and weights') + group.add_argument('--shuffle-sample-in-corpus', action='store_true', help="Whether to shuffle the samples within in the dataset files") + + group.add_argument('--blend-sample-in-corpus', action='store_true', help="Whether to blend different files in the same corpus") + group.add_argument('--split', type=str, default='969, 30, 1', help='Comma-separated list of proportions for training,' ' validation, and test split. For example the split ' diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index b23f6c84b3..78e43e7fed 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -16,7 +16,8 @@ from megatron.core import ModelParallelConfig from deepspeed.accelerator import get_accelerator - +from megatron.utils import Profile +dlp = Profile("PIPELINE") # Types Shape = Union[List[int], torch.Size] @@ -329,6 +330,7 @@ def _ring_exchange_wrapper(**kwargs): return tensor_recv_prev, tensor_recv_next, reqs +@dlp.log def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: """ Receive tensor from previous rank in pipeline (forward receive). @@ -353,7 +355,7 @@ def recv_forward(tensor_shape: Shape, config.timers('forward-recv').stop() return input_tensor - +@dlp.log def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: """Receive tensor from next rank in pipeline (backward receive). @@ -376,7 +378,7 @@ def recv_backward(tensor_shape: Shape, config.timers('backward-recv').stop() return output_tensor_grad - +@dlp.log def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None: """Send tensor to next rank in pipeline (forward send). @@ -397,7 +399,7 @@ def send_forward(output_tensor: torch.Tensor, if config.timers is not None: config.timers('forward-send').stop() - +@dlp.log def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None: """Send tensor to previous rank in pipeline (backward send). 
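The `@dlp.log` decorators added across `p2p_communication.py` wrap every point-to-point helper with the `megatron.utils.Profile` tracer under the `"PIPELINE"` scope. The sketch below shows the general shape of such a name-scoped timing decorator; it is a generic stand-in, not the actual `Profile` implementation:

```python
import functools
import time


class SimpleProfile:
    """Toy stand-in for a name-scoped tracing decorator (illustrative only)."""

    def __init__(self, scope: str):
        self.scope = scope

    def log(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                dt = time.perf_counter() - t0
                print(f"[{self.scope}] {fn.__name__} took {dt * 1e3:.3f} ms")
        return wrapper


dlp = SimpleProfile("PIPELINE")


@dlp.log
def send_forward_stub():
    time.sleep(0.001)  # pretend communication


send_forward_stub()
```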
@@ -417,7 +419,7 @@ def send_backward(input_tensor_grad: torch.Tensor, if config.timers is not None: config.timers('backward-send').stop() - +@dlp.log def send_forward_recv_backward(output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: @@ -441,7 +443,7 @@ def send_forward_recv_backward(output_tensor: torch.Tensor, config.timers('forward-send-backward-recv').stop() return output_tensor_grad - +@dlp.log def send_backward_recv_forward(input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: @@ -465,7 +467,7 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor, config.timers('backward-send-forward-recv').stop() return input_tensor - +@dlp.log def send_forward_recv_forward(output_tensor: torch.Tensor, recv_prev: bool, tensor_shape: Shape, @@ -491,7 +493,7 @@ def send_forward_recv_forward(output_tensor: torch.Tensor, return input_tensor, wait_handles return input_tensor - +@dlp.log def send_backward_recv_backward(input_tensor_grad: torch.Tensor, recv_next: bool, tensor_shape: Shape, @@ -517,7 +519,7 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor, return output_tensor_grad, wait_handles return output_tensor_grad - +@dlp.log def send_forward_backward_recv_forward_backward( output_tensor: torch.Tensor, input_tensor_grad: torch.Tensor, diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index ba2e00b1ef..ab164fdc48 100755 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -6,14 +6,20 @@ import os import time +import logging import numpy as np import torch from deepspeed.accelerator import get_accelerator -from megatron import print_rank_0 +# from megatron import print_rank_0 from megatron.core import mpu from megatron.utils import Profile, PerfTrace from mpi4py import MPI + +from megatron.utils import get_logger + +log = get_logger(__name__, rank_zero_only=True) + dlp = Profile("DATASET") class BlendableDataset(torch.utils.data.Dataset): @dlp.log @@ -35,7 +41,7 @@ def __init__(self, datasets, weights, size, *, # Build indicies. 
@dlp.log def _build_indices(): - start_time = time.time() + start_time = time.perf_counter() dataset_index = np.zeros(self.size, dtype=np.int64) dataset_sample_index = np.zeros(self.size, dtype=np.int64) @@ -43,8 +49,10 @@ def _build_indices(): helpers.build_blending_indices(dataset_index, dataset_sample_index, weights, num_datasets, self.size, torch.distributed.get_rank() == 0) - print_rank_0('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + log.info( + "> elapsed time for building blendable dataset indices: " + f"{time.perf_counter() - start_time:.2f} (sec)" + ) return dataset_index, dataset_sample_index desc = "Blendable dataset\n\n" @@ -68,15 +76,15 @@ def _build_indices(): ' dataset, building indices on rank 0 ...', flush=True) dataset_index, dataset_sample_index = _build_indices() try: - print_rank_0(" > saving index map files") - start_time = time.time() + log.debug(" > saving index map files") + start_time = time.perf_counter() os.makedirs(os.path.dirname(index_path), exist_ok=True) with open(desc_path, 'wt') as fd: fd.write(desc) np.save(index_path, dataset_index, allow_pickle=True) np.save(sample_index_path, dataset_sample_index, allow_pickle=True) - print_rank_0(f" > finished saving index map files in {time.time() - start_time} seconds") + log.info(f" > finished saving index map files in {time.perf_counter() - start_time} seconds") except OSError: print(f'There was an error trying to create the data cache directory ({data_cache_path})') print('or a file in it. This is set with the --data-cache-path argument. Please') @@ -93,21 +101,21 @@ def _build_indices(): torch.distributed.get_world_size() // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) // torch.distributed.get_world_size(group=mpu.get_sequence_parallel_group())): - print_rank_0("Data index creation unsuccessful, exiting.") + log.info("Data index creation unsuccessful, exiting.") exit() ''' torch.distributed.barrier(group=mpu.get_data_parallel_group()) torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) torch.distributed.barrier(group=mpu.get_data_parallel_group()) - - start_time = time.time() - print_rank_0(f'> loading blendable dataset index: {index_path}') + + start_time = time.perf_counter() + log.info(f'> loading blendable dataset index: {index_path}') self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r') assert self.dataset_index.size == self.size - print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}') + log.info(f'> loading blendable dataset sample index: {sample_index_path}') self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r') assert self.dataset_sample_index.size == self.size - print_rank_0(f'> finished loading in {time.time() - start_time} seconds') + log.info(f'> finished loading in {time.perf_counter() - start_time} seconds') else: self.dataset_index, self.dataset_sample_index = _build_indices() @@ -119,7 +127,7 @@ def _build_indices(): raise RuntimeError('BlendedDataset size is improperly bounded') except IndexError: pass - print_rank_0('> size of blendable dataset: ' + log.info('> size of blendable dataset: ' '{} samples'.format(self.size)) @@ -133,4 +141,4 @@ def __getitem__(self, idx): return { "dataset_idx" : dataset_idx, **self.datasets[dataset_idx][sample_idx], - } + } diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py old mode 100755 new mode 100644 index 0cf97356a4..d09f08d63a --- 
a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -9,67 +9,96 @@ import numpy as np import torch from deepspeed.accelerator import get_accelerator -from megatron import print_rank_0, is_rank_0, get_args +from megatron import is_rank_0, get_args from megatron.core import mpu -from megatron.data import helpers +from megatron.data import helpers # type:ignore from megatron.data.blendable_dataset import BlendableDataset -from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_datasets_corpuses_weights_and_num_samples +from megatron.data.dataset_utils import ( + get_datasets_weights_and_num_samples, + get_datasets_corpuses_weights_and_num_samples, +) from megatron.data.dataset_utils import get_train_valid_test_split_ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset -from megatron.utils import PerfTrace, Profile +from megatron.utils import PerfTrace, Profile, get_logger from mpi4py import MPI dlp = Profile("DATASET") +log = get_logger(__name__, rank_zero_only=True) + + @dlp.log -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - train_data_prefix=None, - valid_data_prefix=None, - test_data_prefix=None, - return_doc_ids=False, *, - data_cache_path=None): +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" if data_prefix: - print_rank_0("Single data path provided for train, valid & test") + log.debug("Single data path provided for train, valid & test") # Single dataset. if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) # Blending dataset. # Parse the values. 
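`get_datasets_corpuses_weights_and_num_samples` consumes a flat `weight, prefix, corpus, weight, prefix, corpus, ...` list, the same layout that `ALCF/test_blendable_dataset.py` builds above. A small sketch of reading such a data file list, assuming one `weight path corpus` triple per line (the exact file format is an assumption here):

```python
def read_data_file_list(path: str) -> tuple[list, list]:
    """Flatten a 'weight prefix corpus' file into [w0, p0, c0, w1, p1, c1, ...]."""
    flat, corpora = [], []
    with open(path) as fh:
        for line in fh:
            if not line.strip() or line.startswith("#"):
                continue
            weight, prefix, corpus = line.split()[:3]
            flat += [float(weight), prefix, corpus]
            if corpus not in corpora:
                corpora.append(corpus)
    return flat, corpora


# Example (path follows the fallback used in helpers.sh):
# flat, corpora = read_data_file_list("ALCF/data-lists/aurora/dolma.txt")
```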
- output = get_datasets_corpuses_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_corpuses_weights_and_num_samples( + data_prefix, train_valid_test_num_samples + ) prefixes, corpuses, weights, datasets_train_valid_test_num_samples = output corpus_list = sorted(set(corpuses)) train_num_samples, valid_num_samples, test_num_samples = map( - sum, - zip(*datasets_train_valid_test_num_samples) + sum, zip(*datasets_train_valid_test_num_samples) ) class DatasetBuilder: - ''' + """ This is for building individual dataset from each dataset file - ''' + """ + @dlp.log - def __init__(self, prefix, corpus, data_impl, splits_string, - num_samples, seq_length, seed, skip_warmup, - return_doc_ids, - data_cache_path=data_cache_path, name='train'): + def __init__( + self, + prefix, + corpus, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + return_doc_ids, + data_cache_path=data_cache_path, + name="train", + ): self.prefix = prefix self.data_impl = data_impl self.splits_string = splits_string - if name == 'train': + if name == "train": self.num_samples = num_samples[0] - elif name == 'valid': + elif name == "valid": self.num_samples = num_samples[1] else: self.num_samples = num_samples[2] @@ -84,279 +113,437 @@ def __init__(self, prefix, corpus, data_impl, splits_string, self.desc = prefix + f"{self.num_samples}" + f"{seq_length}" + f"{seed}" self.build = False self.corpus = corpus + @dlp.log def Build(self): - self.dataset = _build_train_valid_test_datasets_single(self.prefix, self.data_impl, self.splits_string, - self.num_samples_train_valid_test, self.seq_length, self.seed, self.skip_warmup, self.name, self.return_doc_ids, - data_cache_path=self.data_cache_path) + self.dataset = _build_train_valid_test_datasets_single( + self.prefix, + self.data_impl, + self.splits_string, + self.num_samples_train_valid_test, + self.seq_length, + self.seed, + self.skip_warmup, + self.name, + self.return_doc_ids, + data_cache_path=self.data_cache_path, + ) self.build = True return self.dataset - class BuildConcatDataset(torch.utils.data.Dataset): + class BuildCorpusDataset(torch.utils.data.Dataset): @dlp.log def __init__(self, dataset_builders): self.dataset_builders = dataset_builders self.num_datasets = len(dataset_builders) self.num_samples = np.sum([d.num_samples for d in dataset_builders]) - self.indices=np.zeros((self.num_samples, 2), dtype=np.uint64) - self.desc="ConcatDataset:" - m = 0 + self.indices = np.zeros((self.num_samples, 2), dtype=np.uint64) + self.desc = "CorpusDataset:" + # m = 0 num_samples_list = np.array([d.num_samples for d in dataset_builders]) self.num_samples = np.sum(num_samples_list) - def _build_indices(): + args = get_args() + + @dlp.log + def _build_indices_blended(): + start_time = time.time() + dataset_index = np.zeros(self.num_samples, dtype=np.int64) + dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64) + weights = num_samples_list / self.num_samples + helpers.build_blending_indices( + dataset_index, dataset_sample_index, + weights, self.num_datasets, self.num_samples, + torch.distributed.get_rank() == 0) + log.debug(f"> elapsed time for building blendable dataset indices for corpus {self.dataset_builders[0].corpus}: " + "{:.2f} (sec)".format(time.time() - start_time)) + return dataset_index, dataset_sample_index + + + def _build_indices_concat(): start_time = time.time() dataset_index = np.zeros(self.num_samples, dtype=np.int64) dataset_sample_index = np.zeros(self.num_samples, dtype=np.int64) - 
helpers.build_concat_indices(dataset_index, dataset_sample_index, - num_samples_list, - self.num_datasets, - torch.distributed.get_rank()==0) - print_rank_0('> elapsed time for building concat dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + helpers.build_concat_indices( + dataset_index, + dataset_sample_index, + num_samples_list, + self.num_datasets, + torch.distributed.get_rank() == 0, + ) + log.debug( + "> elapsed time for building concat dataset indices: " + "{:.2f} (sec)".format(time.time() - start_time) + ) return dataset_index, dataset_sample_index - self.dataset_index, self.dataset_sample_index = _build_indices() + if args.blend_sample_in_corpus: + self.dataset_index, self.dataset_sample_index = _build_indices_blended() + else: + self.dataset_index, self.dataset_sample_index = _build_indices_concat() + + np_rng = np.random.RandomState(seed=dataset_builders[0].seed) + self.shuffle_index = np.arange(self.num_samples) + if args.shuffle_sample_in_corpus: + np_rng.shuffle(self.shuffle_index) for i in range(self.num_datasets): self.desc += dataset_builders[i].prefix + "," - self.desc += f"-{self.num_samples}" + f"-{dataset_builders[0].seq_length}" + f"{dataset_builders[0].seed}" + log.info( + f"[BuildConcatDataset] Caught {args.shuffle_sample_in_corpus=} across" + f" {self.num_samples} samples" + ) + self.desc += ( + f"-{self.num_samples}" + + f"-{dataset_builders[0].seq_length}" + + f"{dataset_builders[0].seed}" + ) + def __len__(self): return self.num_samples @dlp.log def __getitem__(self, idx): - if idx >= self.num_samples: - print_rank_0(f"WARNING: index overflow encountered {idx} > {self.num_samples} for {self.dataset_builders[0].corpus}; will randomly pick one sample") - id = np.random.randint(self.num_samples) - else: - id = idx - i = self.dataset_index[idx] - j = self.dataset_sample_index[idx] + id_shuffle = self.shuffle_index[idx] + i = self.dataset_index[id_shuffle] + j = self.dataset_sample_index[id_shuffle] if self.dataset_builders[i].build: return self.dataset_builders[i].dataset[j] else: return self.dataset_builders[i].Build()[j] - - # Predetermine whether need to build the specific dataset or not. + # Predetermine whether need to build the specific dataset or not. start_time = time.time() - print_rank_0(" >>> Started building datasets in distributed way ... ") + log.debug(" >>> Started building datasets in distributed way ... ") a, b, c = [int(d) for d in splits_string.split(",")] - + train_datasets = [] valid_datasets = [] test_datasets = [] # Build individual datasets. 
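`BuildCorpusDataset.__getitem__` now resolves a sample through three layers of indirection: the optional in-corpus shuffle permutation, the per-corpus `dataset_index`, and the per-dataset `dataset_sample_index`. A simplified NumPy sketch of that lookup for the concatenated (non-blended) path, with toy sizes (illustrative only; the compiled `helpers.build_concat_indices` does the real work):

```python
import numpy as np

# Toy per-file sample counts within one corpus (illustrative).
num_samples_list = np.array([5, 3, 4])
num_samples = int(num_samples_list.sum())

# Plain concatenation, standing in for what helpers.build_concat_indices
# computes: file i contributes its samples back to back.
dataset_index = np.repeat(np.arange(len(num_samples_list)), num_samples_list)
dataset_sample_index = np.concatenate([np.arange(n) for n in num_samples_list])

# Optional in-corpus shuffle, seeded the way the patch seeds it.
shuffle_sample_in_corpus = True
shuffle_index = np.arange(num_samples)
if shuffle_sample_in_corpus:
    np.random.RandomState(seed=1234).shuffle(shuffle_index)


def lookup(idx: int) -> tuple[int, int]:
    """Mimic BuildCorpusDataset.__getitem__'s index resolution."""
    sid = shuffle_index[idx]
    return int(dataset_index[sid]), int(dataset_sample_index[sid])


print([lookup(i) for i in range(num_samples)])
```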
- + args = get_args() @dlp.log - def build_corpus_datasets(dataset_type='train'): + def build_corpus_datasets(dataset_type="train"): start_time = time.time() - print_rank_0(f" >>> Building {dataset_type} corpus datasets ...") + log.debug(f" >>> Building {dataset_type} corpus datasets ...") datasets = [] corpus_builders = {} corpus_weights = {} for c in corpus_list: corpus_builders[c] = [] corpus_weights[c] = 0.0 - dataset_builders = [DatasetBuilder(prefixes[i], corpuses[i], data_impl, splits_string, - datasets_train_valid_test_num_samples[i], - seq_length, seed, skip_warmup, - return_doc_ids,data_cache_path, dataset_type) for i in range(len(weights))] - for i in range(torch.distributed.get_rank()//mpu.get_tensor_model_parallel_world_size(), len(weights), torch.distributed.get_world_size()//mpu.get_tensor_model_parallel_world_size()): + dataset_builders = [ + DatasetBuilder( + prefixes[i], + corpuses[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + return_doc_ids, + data_cache_path, + dataset_type, + ) + for i in range(len(weights)) + ] + for i in range( + torch.distributed.get_rank() + // mpu.get_tensor_model_parallel_world_size(), + len(weights), + torch.distributed.get_world_size() + // mpu.get_tensor_model_parallel_world_size(), + ): dataset_builders[i].Build() - print_rank_0(f" >>> Finished building individual datasets in {time.time() - start_time} seconds") + log.debug( + f" >>> Finished building individual datasets in {time.time() - start_time} seconds" + ) start_concating_time = time.time() for i, d in zip(range(len(weights)), dataset_builders): corpus_builders[d.corpus].append(d) corpus_weights[d.corpus] += weights[i] total = 0 - print_rank_0(" > number of samples for each corpus ") - corpus_weights_achieved={} + log.debug(" > number of samples for each corpus ") + corpus_weights_achieved = {} for c in corpus_list: - datasets.append(BuildConcatDataset(corpus_builders[c])) + datasets.append(BuildCorpusDataset(corpus_builders[c])) total += datasets[-1].num_samples - corpus_weights_achieved[c] = float(datasets[-1].num_samples)/train_num_samples - print_rank_0(f" {c}: {datasets[-1].num_samples} w={corpus_weights_achieved[c]} (expected: {corpus_weights[c]})") - - print_rank_0(f" > total number of samples: {total}") - print_rank_0(f" >>> Finished concatenating datasets in {time.time() - start_concating_time} seconds") - print_rank_0(f" >>> Finished building {dataset_type} corpus datasets in {time.time() - start_time} seconds") + corpus_weights_achieved[c] = ( + float(datasets[-1].num_samples) / train_num_samples + ) + log.debug( + f" {c}: {datasets[-1].num_samples} w={corpus_weights_achieved[c]} (expected: {corpus_weights[c]})" + ) + + log.debug(f" > total number of samples: {total}") + log.debug( + f" >>> Finished concatenating datasets in {time.time() - start_concating_time} seconds" + ) + log.debug( + f" >>> Finished building {dataset_type} corpus datasets in {time.time() - start_time} seconds" + ) return datasets, [corpus_weights_achieved[c] for c in corpus_list] + train_weights = None if a > 0: - train_datasets, train_weights = build_corpus_datasets('train') - + train_datasets, train_weights = build_corpus_datasets("train") + valid_weights = None if b > 0: - valid_datasets, valid_weights = build_corpus_datasets('valid') - - if c > 0: - test_datasets, test_weights = build_corpus_datasets('test') + valid_datasets, valid_weights = build_corpus_datasets("valid") + test_weights = None + if c > 0: + test_datasets, 
test_weights = build_corpus_datasets("test") # This barrier is critical to make sure that all the datasets are built once # and the metadata were written to the cache folder before other ranks touch them - print_rank_0(f" >>> Rank 0 - finished building datasets in {time.time() - start_time} seconds") + log.debug( + f" >>> Rank 0 - finished building datasets in {time.time() - start_time} seconds" + ) torch.distributed.barrier(group=mpu.get_data_parallel_group()) torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) torch.distributed.barrier(group=mpu.get_data_parallel_group()) - print_rank_0(f" >>> Finished building datasets (all ranks) in distributed way in {time.time() - start_time} seconds") - print_rank_0(f" >>> Starting to build BlendableDataset") + log.debug( + f" >>> Finished building datasets (all ranks) in distributed way in {time.time() - start_time} seconds" + ) + log.debug(" >>> Starting to build BlendableDataset") # Blend. start_time = time.time() blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, train_weights, train_num_samples, - data_cache_path=data_cache_path) + if train_datasets and train_weights: + blending_train_dataset = BlendableDataset( + train_datasets, + train_weights, + train_num_samples, + data_cache_path=data_cache_path, + ) blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, valid_weights, valid_num_samples, - data_cache_path=data_cache_path) + if valid_datasets and valid_weights: + blending_valid_dataset = BlendableDataset( + valid_datasets, + valid_weights, + valid_num_samples, + data_cache_path=data_cache_path, + ) blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, test_weights, test_num_samples, - data_cache_path=data_cache_path) + if test_datasets and test_weights: + blending_test_dataset = BlendableDataset( + test_datasets, + test_weights, + test_num_samples, + data_cache_path=data_cache_path, + ) end_time = time.time() - print_rank_0(f" >>> Finished building BlendableDataset in {end_time - start_time} seconds") - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + log.debug( + f" >>> Finished building BlendableDataset in {end_time - start_time} seconds" + ) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) else: - print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.") + log.debug( + "Separate data paths provided for train, valid & test. Split string will be ignored." + ) train_dataset, valid_dataset, test_dataset = None, None, None # Single dataset. 
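The weights handed to `BlendableDataset` here are the achieved per-corpus fractions (samples actually built for a corpus divided by the requested training sample count), not the nominal weights from the data file list. A small sketch of that bookkeeping with made-up corpus names and counts:

```python
# Made-up per-corpus sample counts after building (illustrative).
corpus_samples = {"common-crawl": 700_000, "stack": 200_000, "arxiv": 100_000}
train_num_samples = 1_000_000

corpus_weights_achieved = {
    c: n / train_num_samples for c, n in corpus_samples.items()
}

for c, w in corpus_weights_achieved.items():
    print(f"{c}: {corpus_samples[c]} samples, achieved weight {w:.3f}")
print(f"total samples: {sum(corpus_samples.values())}")
```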
if train_data_prefix is not None: - train_dataset = build_dataset("train", train_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[0], - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + train_dataset = build_dataset( + "train", + train_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[0], + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) if valid_data_prefix is not None: - valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[1], - seq_length, seed, False, - data_cache_path=data_cache_path) - + valid_dataset = build_dataset( + "valid", + valid_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[1], + seq_length, + seed, + False, + data_cache_path=data_cache_path, + ) if test_data_prefix is not None: - test_dataset = build_dataset("test", test_data_prefix, data_impl, - splits_string, - train_valid_test_num_samples[2], - seq_length, seed, False, - data_cache_path=data_cache_path) + test_dataset = build_dataset( + "test", + test_data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples[2], + seq_length, + seed, + False, + data_cache_path=data_cache_path, + ) return (train_dataset, valid_dataset, test_dataset) + @dlp.log -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, - return_doc_ids=False, *, - data_cache_path=None): +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. 
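`get_train_valid_test_split_` (unchanged by this patch) turns the comma-separated proportions into four document-index boundaries, and each split then covers `[splits[index], splits[index + 1])`, as the stats printed in the next hunk show. A simplified sketch of that mapping (the real helper also handles rounding and fill-to-total details):

```python
def split_boundaries(splits_string: str, total_docs: int) -> list[int]:
    """Rough version: normalize proportions and take cumulative boundaries."""
    parts = [float(s) for s in splits_string.split(",")]
    norm = [p / sum(parts) for p in parts]
    bounds = [0]
    for p in norm:
        bounds.append(bounds[-1] + int(round(p * total_docs)))
    bounds[-1] = total_docs  # make sure the last boundary covers every document
    return bounds


# Matches the --split default used by setup_run_cmd: 990,10,0
splits = split_boundaries("990,10,0", 1_000_000)
for name, i in (("train", 0), ("validation", 1), ("test", 2)):
    print(f"{name}: documents in [{splits[i]}, {splits[i + 1]})")
```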
- print_rank_0(' > dataset split:') + log.debug(" > dataset split:") def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + log.debug(" {}:".format(name)) + log.debug( + " document indices in [{}, {}) total of {} " "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, - splits_string, - train_valid_test_num_samples[index], - seq_length, seed, - return_doc_ids, - data_cache_path=data_cache_path) + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, + seed, + return_doc_ids, + data_cache_path=data_cache_path, + ) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) + @dlp.log -def _build_train_valid_test_datasets_single(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup, name, - return_doc_ids=False, *, - data_cache_path=None): +def _build_train_valid_test_datasets_single( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, + name, + return_doc_ids=False, + *, + data_cache_path=None, +): """Build train, valid, and test datasets.""" # Each rank print out information - print_rank_0(f" >> building dataset for {data_prefix}") + log.debug(f" >> building dataset for {data_prefix}") # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. 
- print_rank_0(' > dataset split:') + log.debug(" > dataset split:") def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + log.debug(" {}:".format(name)) + log.debug( + " document indices in [{}, {}) total of {} " "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, - splits_string, - train_valid_test_num_samples[index], - seq_length, seed, - return_doc_ids, - data_cache_path=data_cache_path) + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, + seed, + return_doc_ids, + data_cache_path=data_cache_path, + ) return dataset - if name.find("train")!=-1: - return build_dataset(0, 'train') - if name.find("valid")!=-1: - return build_dataset(1, 'valid') - if name.find("test")!=-1: - return build_dataset(2, 'test') + + if name.find("train") != -1: + return build_dataset(0, "train") + if name.find("valid") != -1: + return build_dataset(1, "valid") + if name.find("test") != -1: + return build_dataset(2, "test") + @dlp.log -def build_dataset(dataset_name, data_prefix, data_impl, - splits_string, num_samples, - seq_length, seed, skip_warmup, - *, - data_cache_path=None): +def build_dataset( + dataset_name, + data_prefix, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + *, + data_cache_path=None, +): dataset = None if len(data_prefix) == 1: - dataset = _build_dataset(dataset_name, data_prefix[0], data_impl, - splits_string, num_samples, seq_length, - seed, skip_warmup, - data_cache_path=data_cache_path) + dataset = _build_dataset( + dataset_name, + data_prefix[0], + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) else: # Blending dataset. # Parse the values. @@ -367,73 +554,108 @@ def build_dataset(dataset_name, data_prefix, data_impl, # Build individual datasets. 
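When several prefixes are given, `data_prefix` is assumed to alternate weight, path, weight, path, and the normalized weights decide how many samples each corpus contributes to the blend. A hedged sketch of that bookkeeping; the 0.5% over-allocation factor is illustrative (mirroring common Megatron practice) and `parse_blend` is not the library function.

```python
import math


def parse_blend(data_prefix, num_samples):
    """Split alternating weight/path pairs into prefixes, weights, and per-dataset sample counts."""
    assert len(data_prefix) % 2 == 0, "expected alternating weight/path pairs"
    weights = [float(w) for w in data_prefix[0::2]]
    prefixes = [str(p) for p in data_prefix[1::2]]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Request slightly more samples per dataset so the blend never runs short.
    per_dataset = [int(math.ceil(w * num_samples * 1.005)) for w in weights]
    return prefixes, weights, per_dataset


prefixes, weights, counts = parse_blend(["0.7", "corpus_a", "0.3", "corpus_b"], 10_000)
print(prefixes, weights, counts)   # ['corpus_a', 'corpus_b'] [0.7, 0.3] [7035, 3015]
```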
datasets = [] for i in range(len(prefixes)): - ds = _build_dataset(dataset_name, prefixes[i], data_impl, - splits_string, dataset_num_samples[i], - seq_length, seed, skip_warmup, - data_cache_path=data_cache_path) + ds = _build_dataset( + dataset_name, + prefixes[i], + data_impl, + splits_string, + dataset_num_samples[i], + seq_length, + seed, + skip_warmup, + data_cache_path=data_cache_path, + ) if ds: datasets.append(ds) if datasets: - dataset = BlendableDataset(datasets, weights, num_samples, - data_cache_path=data_cache_path) + dataset = BlendableDataset( + datasets, weights, num_samples, data_cache_path=data_cache_path + ) return dataset + @dlp.log -def _build_dataset(dataset_name, data_prefix, data_impl, splits_string, - num_samples, seq_length, seed, skip_warmup, - *, - data_cache_path=None): +def _build_dataset( + dataset_name, + data_prefix, + data_impl, + splits_string, + num_samples, + seq_length, + seed, + skip_warmup, + *, + data_cache_path=None, +): """ Build dataset. This method is called when individual train, valid, test datasets are provided """ # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] - print_rank_0(' {}:'.format(dataset_name)) - print_rank_0(' document indices in [0, {}) total of {} ' - 'documents'.format(total_num_of_documents, total_num_of_documents)) - - documents = np.arange(start=0, stop=total_num_of_documents, - step=1, dtype=np.int32) - - dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset, - splits_string, num_samples, seq_length, seed, - data_cache_path=data_cache_path) + log.debug(" {}:".format(dataset_name)) + log.debug( + " document indices in [0, {}) total of {} " "documents".format( + total_num_of_documents, total_num_of_documents + ) + ) + + documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) + + dataset = GPTDataset( + dataset_name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + data_cache_path=data_cache_path, + ) return dataset + @dlp.log def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): """Build indexed dataset.""" - print_rank_0(' > building dataset index ...') + log.debug(" > building dataset index ...") start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) - print_rank_0(' > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time)) - print_rank_0(' number of documents: {}'.format( - indexed_dataset.sizes.shape[0])) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + log.debug( + " > finished creating indexed dataset in {:4f} " "seconds".format( + time.time() - start_time + ) + ) + log.debug(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) return indexed_dataset class GPTDataset(torch.utils.data.Dataset): @dlp.log - def __init__(self, name, data_prefix, documents, indexed_dataset, - splits_string, num_samples, seq_length, seed, - return_doc_ids=False, *, - data_cache_path=None): - + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + return_doc_ids=False, + *, + data_cache_path=None, + ): self.name = name self.indexed_dataset = indexed_dataset self.return_doc_ids = return_doc_ids @@ -443,20 +665,29 @@ def __init__(self, name, 
data_prefix, documents, indexed_dataset, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ - _build_index_mappings(self.name, data_prefix, - documents, self.indexed_dataset.sizes, - splits_string, num_samples, seq_length, seed, - data_cache_path=data_cache_path) - + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = ( + _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + splits_string, + num_samples, + seq_length, + seed, + data_cache_path=data_cache_path, + ) + ) def __len__(self): # -1 is due to data structure used to retieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) return self.sample_idx.shape[0] - 1 + @dlp.log def __getitem__(self, idx): args = get_args() + assert args is not None orig_idx = idx # Get the shuffled index. try: @@ -465,21 +696,24 @@ def __getitem__(self, idx): if is_rank_0(): import json from rich import print_json + print(exc) print( - '\n'.join( - ['-------------------------------------------------', - f'Trying to access {idx=} from self.shuffle_idx,', - f'but {len(self.shuffle_idx)=}', - '-------------------------------------------------'] + "\n".join( + [ + "-------------------------------------------------", + f"Trying to access {idx=} from self.shuffle_idx,", + f"but {len(self.shuffle_idx)=}", + "-------------------------------------------------", + ] ) ) print_json( json.dumps( { - 'doc_idx': len(self.doc_idx), - 'sample_idx': len(self.sample_idx), - 'shuffle_idx': len(self.shuffle_idx), + "doc_idx": len(self.doc_idx), + "sample_idx": len(self.sample_idx), + "shuffle_idx": len(self.shuffle_idx), }, indent=4, ) @@ -493,45 +727,57 @@ def __getitem__(self, idx): doc_ids = [] if doc_index_f == doc_index_l: doc_ids.append(self.doc_idx[doc_index_f]) - sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1) + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) else: # Otherwise, get the rest of the initial document. doc_ids.append(self.doc_idx[doc_index_f]) - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f)] + sample_list = [ + self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): doc_ids.append(self.doc_idx[i]) sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) # And finally add the relevant portion of last document. 
doc_ids.append(self.doc_idx[doc_index_l]) - sample_list.append(self.indexed_dataset.get( - self.doc_idx[doc_index_l], - length=offset_l + 1)) + sample_list.append( + self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) sample = np.concatenate(sample_list) - text_name = 'text' + text_name = "text" if args.use_dataset_only: - text_name = 'input_ids' + text_name = "input_ids" sample_dict = {text_name: np.array(sample, dtype=np.int64)} if args.return_data_index: - sample_dict.update({'index': np.array([orig_idx], dtype=np.int64)}) + sample_dict.update({"index": np.array([orig_idx], dtype=np.int64)}) - if self.return_doc_ids: # for retro preprocessing - sample_dict.update({'doc_ids': np.array(doc_ids, dtype=np.int64)}) + if self.return_doc_ids: # for retro preprocessing + sample_dict.update({"doc_ids": np.array(doc_ids, dtype=np.int64)}) if args.use_dataset_only: - sample_dict.update({'labels': np.array(sample, dtype=np.int64)}) + sample_dict.update({"labels": np.array(sample, dtype=np.int64)}) return sample_dict + @dlp.log -def _build_index_mappings(name, data_prefix, documents, sizes, - splits_string, num_samples, seq_length, seed, - *, - data_cache_path): +def _build_index_mappings( + name, + data_prefix, + documents, + sizes, + splits_string, + num_samples, + seq_length, + seed, + *, + data_cache_path, +): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. sample-idx: is the start document index and document offset for each @@ -539,10 +785,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes, shuffle-idx: maps the sample index into a random index into sample-idx. """ args = get_args() + assert args is not None # Number of tokens in each epoch and number of required epochs. 
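A small worked example of the epoch arithmetic used just below: one epoch holds `tokens_per_epoch` tokens, and the number of epochs is the smallest `e` with `((e * tokens_per_epoch - 1) // seq_length) >= num_samples`, matching `_num_epochs` further down in this diff. Names here are local to the example.

```python
import numpy as np


def num_tokens(documents, sizes):
    """Total tokens contributed by the selected documents in one epoch."""
    return int(np.sum(sizes[documents]))


def num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Smallest epoch count that yields at least num_samples sequences."""
    epochs, total_tokens = 0, 0
    while True:
        epochs += 1
        total_tokens += tokens_per_epoch
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return epochs


sizes = np.array([300, 500, 200, 1000], dtype=np.int32)
documents = np.arange(4, dtype=np.int32)
tpe = num_tokens(documents, sizes)                            # 2000 tokens per epoch
print(tpe, num_epochs(tpe, seq_length=512, num_samples=10))   # 2000 3
```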
tokens_per_epoch = _num_tokens(documents, sizes) num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) - if args.train_data_exact_num_epochs is not None and name == 'train': + if args.train_data_exact_num_epochs is not None and name == "train": num_epochs = args.train_data_exact_num_epochs # rng state @@ -557,13 +804,13 @@ def _build_index_mappings(name, data_prefix, documents, sizes, desc += f"Sequence length {seq_length}\n" desc += f"Random seed {seed}\n" desc += f"Split {splits_string}\n" - desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest() desc_filename = desc_hash + ".dsc" - doc_idx_filename = desc_hash + '_doc_idx.npy' - sample_idx_filename = desc_hash + '_sample_idx.npy' - shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + doc_idx_filename = desc_hash + "_doc_idx.npy" + sample_idx_filename = desc_hash + "_sample_idx.npy" + shuffle_idx_filename = desc_hash + "_shuffle_idx.npy" - if name == 'train': + if name == "train": # force to use certain index files if args.train_desc_path is not None: desc_filename = args.train_desc_path @@ -578,15 +825,15 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # duplication, then look in data-cache-path if specified, # If nothing is found, use the last path looked in build_indices = True - prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + prefixes = [os.path.join(os.path.dirname(data_prefix), "index-cache")] if data_cache_path is not None: prefixes.append(data_cache_path) for prefix in prefixes: idx_path = { - 'desc': os.path.join(prefix, desc_filename), - 'doc': os.path.join(prefix, doc_idx_filename), - 'sample': os.path.join(prefix, sample_idx_filename), - 'shuffle': os.path.join(prefix, shuffle_idx_filename) + "desc": os.path.join(prefix, desc_filename), + "doc": os.path.join(prefix, doc_idx_filename), + "sample": os.path.join(prefix, sample_idx_filename), + "shuffle": os.path.join(prefix, shuffle_idx_filename), } for f in idx_path.values(): if not os.path.isfile(f): @@ -595,15 +842,17 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Found our files! build_indices = False break - data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_dir = os.path.dirname(idx_path["desc"]) data_cache_success = True # Build the indexed mapping if not exist. if build_indices: - # Since this function will be called by all the rank in the very beginning. Therefore, we assume that all the - # ranks will first create the document files, and then read it. + # Since this function will be called by all the rank in the very beginning. Therefore, we assume that all the + # ranks will first create the document files, and then read it. # There will not be contension effects going on either - print_rank_0(f" > WARNING: could not find index map files, building on rank {torch.distributed.get_rank()}") + log.warning( + f" > WARNING: could not find index map files, building on rank {torch.distributed.get_rank()}" + ) # For the last epoch, decide whether include the entire epoch # in the global shuffle or not. @@ -612,64 +861,80 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # not mean anything. 
if num_epochs == 1: separate_last_epoch = False - print_rank_0(' > only one epoch required, setting ' - 'separate_last_epoch to False') + log.debug( + " > only one epoch required, setting " "separate_last_epoch to False" + ) else: # Get the number of samples for the last epoch num_samples_from_epochs_minus_one = ( - (num_epochs - 1) * tokens_per_epoch - 1) // seq_length - last_epoch_num_samples = num_samples - \ - num_samples_from_epochs_minus_one - assert last_epoch_num_samples >= 0, \ - 'last epoch number of samples should be non-negative.' + (num_epochs - 1) * tokens_per_epoch - 1 + ) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert ( + last_epoch_num_samples >= 0 + ), "last epoch number of samples should be non-negative." num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length - assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ - 'last epoch number of samples exceeded max value.' + assert last_epoch_num_samples <= ( + num_samples_per_epoch + 1 + ), "last epoch number of samples exceeded max value." # If we have less than 80% of the samples for the last epoch, # seperate out the epoch and treat it differently. # Note: the 80% number is just based on common sense and can # be adjusted if needed. - separate_last_epoch = (last_epoch_num_samples < - int(0.80 * num_samples_per_epoch)) + separate_last_epoch = last_epoch_num_samples < int( + 0.80 * num_samples_per_epoch + ) if separate_last_epoch: - string = ' > last epoch number of samples ({}) is smaller '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to True' + string = ( + " > last epoch number of samples ({}) is smaller " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to True" + ) else: - string = ' > last epoch number of samples ({}) is larger '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to False' - print_rank_0(string.format(last_epoch_num_samples, - num_samples_per_epoch)) - + string = ( + " > last epoch number of samples ({}) is larger " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to False" + ) + log.debug(string.format(last_epoch_num_samples, num_samples_per_epoch)) try: os.makedirs(data_cache_dir, exist_ok=True) # description - with open(idx_path['desc'], 'wt') as fd: + with open(idx_path["desc"], "wt") as fd: fd.write(desc) # doc-idx. start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng, - separate_last_epoch) - np.save(idx_path['doc'], doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(idx_path["doc"], doc_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # sample-idx. start_time = time.time() # Use C++ implementation for speed. # First compile and then import. 
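A worked example, with made-up numbers, of the `separate_last_epoch` decision above: if the final epoch would contribute fewer than 80% of the samples of a full epoch, it is shuffled separately from the earlier epochs.

```python
tokens_per_epoch, seq_length = 100_000, 2048
num_epochs, num_samples = 3, 120

samples_from_first_epochs = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length  # 97
last_epoch_num_samples = num_samples - samples_from_first_epochs                     # 23
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length                         # 48
separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch)     # 23 < 38 -> True

print(samples_from_first_epochs, last_epoch_num_samples,
      num_samples_per_epoch, separate_last_epoch)
```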
from megatron.data import helpers + assert doc_idx.dtype == np.int32 assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch, torch.distributed.get_rank()==0) - np.save(idx_path['sample'], sample_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + sample_idx = helpers.build_sample_idx( + sizes, + doc_idx, + seq_length, + num_epochs, + tokens_per_epoch, + torch.distributed.get_rank() == 0, + ) + np.save(idx_path["sample"], sample_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # shuffle-idx. start_time = time.time() # -1 is due to data structure used to retieve the index: @@ -678,35 +943,46 @@ def _build_index_mappings(name, data_prefix, documents, sizes, num_samples_ = num_samples_from_epochs_minus_one else: num_samples_ = sample_idx.shape[0] - 1 - shuffle_idx = _build_shuffle_idx(num_samples_, - sample_idx.shape[0] - 1, np_rng) - np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) + shuffle_idx = _build_shuffle_idx( + num_samples_, sample_idx.shape[0] - 1, np_rng + ) + np.save(idx_path["shuffle"], shuffle_idx, allow_pickle=True) + log.debug( + " > elasped time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) except OSError: - print(f'There was an error trying to create the data cache directory ({data_cache_dir})') - print('or a file in it. This defaults to a directory "index-cache" within the directory') - print('the data files are in and can be set with the --data-cache-path argument. Please') - print('ensure you have write access to this directory or specify one that you do have') - print('write access to.') + print( + f"There was an error trying to create the data cache directory ({data_cache_dir})" + ) + print( + 'or a file in it. This defaults to a directory "index-cache" within the directory' + ) + print( + "the data files are in and can be set with the --data-cache-path argument. Please" + ) + print( + "ensure you have write access to this directory or specify one that you do have" + ) + print("write access to.") data_cache_success = False # Load mappings. 
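The mapping arrays saved above are read back below with `mmap_mode="r"`, so large index files are paged in lazily instead of being loaded into RAM up front. A tiny illustration with a throwaway file:

```python
import os
import tempfile

import numpy as np

path = os.path.join(tempfile.mkdtemp(), "doc_idx.npy")
np.save(path, np.arange(1_000_000, dtype=np.int32), allow_pickle=True)

# Returns a numpy.memmap: elements are read from disk on demand.
doc_idx = np.load(path, allow_pickle=True, mmap_mode="r")
print(type(doc_idx), doc_idx[123456])
```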
start_time = time.time() - print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") - doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path["doc"], allow_pickle=True, mmap_mode="r") - print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") - sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path["sample"], allow_pickle=True, mmap_mode="r") - print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") - shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + log.debug(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path["shuffle"], allow_pickle=True, mmap_mode="r") - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - sample_idx.shape[0])) - print_rank_0(' total number of epochs: {}'.format(num_epochs)) + log.debug( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + log.debug(" total number of samples: {}".format(sample_idx.shape[0])) + log.debug(" total number of epochs: {}".format(num_epochs)) return doc_idx, sample_idx, shuffle_idx, desc, desc_hash @@ -730,25 +1006,26 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples): if ((total_tokens - 1) // seq_length) >= num_samples: return num_epochs + @dlp.log def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): """Build an array with length = number-of-epochs * number-of-dcuments. Each index is mapped to a corresponding document.""" if not separate_last_epoch or num_epochs == 1: - doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] doc_idx[:] = documents doc_idx = doc_idx.reshape(-1) doc_idx = doc_idx.astype(np.int32) np_rng.shuffle(doc_idx) return doc_idx - doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) return np.concatenate((doc_idx_first, doc_idx_last)) + @dlp.log -def _build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch): +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): """Sample index mapping is a 2D array with sizes [number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is the @@ -782,7 +1059,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length, # Note that -1 here is for the same reason we have -1 in # `_num_epochs` calculations. if remaining_seq_length <= 0: - doc_offset += (remaining_seq_length + doc_length - 1) + doc_offset += remaining_seq_length + doc_length - 1 remaining_seq_length = 0 else: # Otherwise, start from the begining of the next document. 
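Taken together, the three mappings are consumed by `GPTDataset.__getitem__` shown earlier: `shuffle_idx` randomizes the sample order, `sample_idx` gives the (document, offset) bounds of each sample, and `doc_idx` resolves those positions into the shuffled epoch ordering. A simplified, pure-Python sketch under those assumptions (toy data, no caching, no dtype handling):

```python
import numpy as np


def fetch_sample(idx, doc_idx, sample_idx, shuffle_idx, get_doc):
    """Assemble one training sample from the three index mappings."""
    idx = shuffle_idx[idx]                       # randomized sample order
    doc_f, off_f = sample_idx[idx]               # first document and start offset
    doc_l, off_l = sample_idx[idx + 1]           # last document and end offset
    if doc_f == doc_l:
        return get_doc(doc_idx[doc_f])[off_f:off_l + 1]
    pieces = [get_doc(doc_idx[doc_f])[off_f:]]                       # tail of first doc
    pieces += [get_doc(doc_idx[i]) for i in range(doc_f + 1, doc_l)] # full middle docs
    pieces.append(get_doc(doc_idx[doc_l])[:off_l + 1])               # head of last doc
    return np.concatenate(pieces)


corpus = {0: np.arange(0, 10), 1: np.arange(10, 20)}             # two toy "documents"
doc_idx = np.array([0, 1], dtype=np.int32)
sample_idx = np.array([[0, 0], [0, 7], [1, 4]], dtype=np.int32)  # two samples
shuffle_idx = np.array([1, 0], dtype=np.int32)
print(fetch_sample(0, doc_idx, sample_idx, shuffle_idx, corpus.__getitem__))
# -> [ 7  8  9 10 11 12 13 14]
```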
@@ -795,24 +1072,28 @@ def _build_sample_idx(sizes, doc_idx, seq_length, return sample_idx + @dlp.log def _build_shuffle_idx(num_samples, total_size, np_rng): """Build the range [0, size) and shuffle.""" - print_rank_0(' > building shuffle index with split [0, {}) and [{}, {}) ' - '...'.format(num_samples, num_samples, total_size)) + log.debug( + " > building shuffle index with split [0, {}) and [{}, {}) " "...".format( + num_samples, num_samples, total_size + ) + ) dtype_ = np.uint32 if total_size >= (np.iinfo(np.uint32).max - 1): dtype_ = np.int64 - shuffle_idx_first = np.arange(start=0, stop=num_samples, - step=1, dtype=dtype_) + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) np_rng.shuffle(shuffle_idx_first) if num_samples == total_size: return shuffle_idx_first - shuffle_idx_last = np.arange(start=num_samples, stop=total_size, - step=1, dtype=dtype_) + shuffle_idx_last = np.arange( + start=num_samples, stop=total_size, step=1, dtype=dtype_ + ) np_rng.shuffle(shuffle_idx_last) return np.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 1eb9b7842b..e2a0c4751f 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -15,17 +15,22 @@ from functools import lru_cache import os + import shutil import struct from itertools import accumulate import numpy as np import torch -from megatron import print_rank_0 -from megatron.utils import Profile + +# from megatron import print_rank_0 +from megatron.utils import Profile, get_logger + +log = get_logger(__name__) dlp = Profile("DATASET") + def __best_fitting_dtype(vocab_size=None): if vocab_size is not None and vocab_size < 65500: return np.uint16 @@ -34,28 +39,32 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['lazy', 'cached', 'mmap'] + return ["lazy", "cached", "mmap"] def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) return None def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': - return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) else: return IndexedDatasetBuilder(out_file) @@ -63,22 +72,24 @@ def make_builder(out_file, impl, vocab_size=None): def make_dataset(path, impl, skip_warmup=False): if not IndexedDataset.exists(path): print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." 
+ ) return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -114,11 +125,11 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -131,38 +142,41 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() self.path = path self.data_file = None self.read_index(path) + @dlp.log def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -176,7 +190,7 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -185,7 +199,7 @@ def __getitem__(self, idx): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -205,8 +219,8 @@ def size(self, index): @staticmethod def exists(path): - return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) ) @property @@ -215,7 +229,6 @@ def supports_prefetch(self): class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -224,6 +237,7 @@ def __init__(self, path): @property def supports_prefetch(self): return True + @dlp.log def prefetch(self, indices): if all(i in self.cache_index for i in indices): @@ -240,7 +254,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx: ptx + size] + a = self.cache[ptx : ptx + size] 
self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -255,10 +269,10 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx: ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -278,15 +292,17 @@ class IndexedDatasetBuilder(object): np.float32: 4, np.float64: 8, } + @dlp.log def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] self.sizes = [] self.element_size = self.element_sizes[self.dtype] self.doc_idx = [0] + @dlp.log def add_item(self, tensor): bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) @@ -297,6 +313,7 @@ def add_item(self, tensor): def end_document(self): self.doc_idx.append(len(self.sizes)) + @dlp.log def merge_file_(self, another_file): index = IndexedDataset(another_file) @@ -315,7 +332,7 @@ def merge_file_(self, another_file): self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -325,21 +342,22 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= 2 and accelerator is not None: accelerator.empty_cache() - # XXX: [saforem2]: ---------------------------------------------------- # Is `num_zeros_in_grad` worth calculating (/ implementing) ?? # the `Megatron`-specific implementation is at: @@ -1406,6 +1406,16 @@ def evaluate_and_print_results( config, verbose, ) + key = "test" if test else "val" + if wandb is not None and wandb.run is not None: + wandb.log({ + f"{key}/iteration": iteration, + **{f"{key}/{k}": v for k, v in total_loss_dict.items()}, + **{ + f"{key}/ppl_{k}": math.exp(min(20, v.item())) + for k, v in total_loss_dict.items() + }, + }) string = " validation loss at {} | ".format(prefix) for key in total_loss_dict: string += f"{key} value={total_loss_dict[key].item():.6f}" @@ -1451,6 +1461,7 @@ def evaluate_and_print_results( log.info("-" * length) log.info(string) log.info("-" * length) + return total_loss_dict def cyclic_iter(iter): diff --git a/megatron/utils.py b/megatron/utils.py index 31c8e20508..dc1dea0b3a 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -1,76 +1,103 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""General utilities.""" import sys import os -import time import logging from typing import Optional -# from ezpz.dist import get_rank import torch from torch.nn.parallel import DistributedDataParallel as torchDDP from deepspeed.accelerator import get_accelerator +from megatron import get_args, get_adlr_autoresume, get_num_microbatches +from megatron.core import mpu +from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate +from megatron.model.module import param_is_not_shared +from megatron.model.rotary_pos_embedding import RotaryEmbedding + +import ezpz as ez + ACCELERATOR = get_accelerator() assert ACCELERATOR is not None if ACCELERATOR.device_name() == "cuda": try: - from apex.multi_tensor_apply import multi_tensor_applier # type: ignore + from apex.multi_tensor_apply import multi_tensor_applier # type:ignore import amp_C # type:ignore HAS_APEX = True except Exception: HAS_APEX = False -from megatron import get_args, get_adlr_autoresume, get_num_microbatches -from megatron.core import mpu -from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate -from megatron.model.module import param_is_not_shared -from megatron.model.rotary_pos_embedding import RotaryEmbedding - -import ezpz as ez - RANK = ez.get_rank() log = logging.getLogger(__name__) -# log.setLevel("INFO") if RANK == 0 else log.setLevel("CRITICAL") - +log.setLevel(os.environ.get("LOG_LEVEL", ("INFO" if RANK == 0 else "CRITICAL"))) _DLIO_PROFILER_EXIST = True +_DFTRACER_EXIST = True + +try: + import dftracer # type:ignore +except Exception: + _DFTRACER_EXIST = False + try: - import dlio_profiler # type: ignore + import dlio_profiler # type:ignore except Exception: _DLIO_PROFILER_EXIST = False -if _DLIO_PROFILER_EXIST: + +if _DFTRACER_EXIST: + from dftracer.logger import ( # type:ignore + dftracer as PerfTrace, + dft_fn as Profile, + DFTRACER_ENABLE as DFTRACER_ENABLE, + ) +elif _DLIO_PROFILER_EXIST: from dlio_profiler.logger import fn_interceptor as Profile # type:ignore from dlio_profiler.logger import dlio_logger as PerfTrace # type:ignore else: from functools import wraps - class Profile: - def __init__(self, type="PROFILER"): - self._start = time.perf_counter() - self.type = type + class Profile(object): + def __init__( + self, cat, name=None, epoch=None, step=None, image_idx=None, image_size=None + ): + return def log(self, func): return func - def iter(self, a): - return a + def log_init(self, func): + return func + + def iter(self, func, iter_name="step"): + return func def __enter__(self): - self._start = time.perf_counter() + return + + def __exit__(self, type, value, traceback): + return - def __exit__(self, *args, **kwargs): - dt = time.perf_counter() - self._start - log.info(f"{self.type} took: {dt:.6f}s") + def update( + self, epoch=None, step=None, image_idx=None, image_size=None, args={} + ): + return + + def flush(self): + return + + def reset(self): + return - class dlio_logger: + def log_static(self, func): + return + + class dftracer(object): def __init__( self, ): @@ -79,16 +106,29 @@ def __init__( def initialize_log(self, logfile=None, data_dir=None, process_id=-1): return - def iter(self, a): - return a + def get_time(self): + return + + def enter_event(self): + return + + def exit_event(self): + return + + def log_event(self, name, cat, start_time, duration, string_args=None): + return - PerfTrace = dlio_logger() + def finalize(self): + return + + PerfTrace = dftracer() + DFTRACER_ENABLE = False def get_logger( name: str, - level: str = "INFO", - 
rank_zero_only: Optional[bool] = None, + level: Optional[str] = None, + rank_zero_only: Optional[bool] = True, ) -> logging.Logger: """Returns a `logging.Logger` object. @@ -96,7 +136,9 @@ def get_logger( non-zero ranks (and will be set to `level` on RANK==0). """ logger = logging.getLogger(name) - logger.setLevel(level) + logger.setLevel( + str(level if level is not None else os.environ.get("LOG_LEVEL", "INFO")).upper() + ) if rank_zero_only and ez.get_rank() != 0: logger.setLevel("CRITICAL") return logger @@ -428,6 +470,7 @@ def throughput_calculator(model, args, iteration_time, total_iterations): num_layers = args.num_layers vocab_size = args.padded_vocab_size gqa = args.num_attention_heads // args.num_key_value_heads + num_experts_routed_to = args.topk ffn_multiplier = 3 if args.swiglu else 2 macs_per_flops = 2 @@ -436,7 +479,7 @@ def throughput_calculator(model, args, iteration_time, total_iterations): # correction has been made to TFLOPs formula due to incorrect behavior # observed with selective recompute when GQA not used and for all with GQA seq_len = args.seq_length - if hasattr(args, "actual_seq_length"): + if hasattr(args, 'actual_seq_length'): seq_len = args.actual_seq_length pre_and_post_mha_gemm_macs = ( batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len @@ -451,6 +494,7 @@ def throughput_calculator(model, args, iteration_time, total_iterations): * ffn_hidden_size * hidden_size * seq_len + * num_experts_routed_to ) logit_lmhead_gemm_macs = batch_size * vocab_size * hidden_size * seq_len diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py index 12a05c5299..3686c6ceeb 100644 --- a/pretrain_gpt_alcf.py +++ b/pretrain_gpt_alcf.py @@ -1,6 +1,7 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain GPT""" + import time from typing import Callable from mpi4py import MPI @@ -103,7 +104,7 @@ def model_provider(pre_process=True, post_process=True): with deepspeed_zero_init( data_parallel_group=dpg, remote_device=(None if args.remote_device == "none" else args.remote_device), - config_dict_or_path=args.deepspeed_config_dict, + config_dict_or_path=args.deepspeed_config, # _dict, enabled=args.zero_stage == 3, mpu=mpu, ): diff --git a/train_aGPT_7B.sh b/train_aGPT_7B.sh index a6a2db72ab..1350ea0f2a 100644 --- a/train_aGPT_7B.sh +++ b/train_aGPT_7B.sh @@ -1,4 +1,7 @@ #!/bin/bash --login +#PBS -q lustre_scaling +#PBS -A Aurora_Deployment +#PBS -j oe ##################################### # AuroraGPT-7B @@ -10,25 +13,28 @@ # 1. Navigate into `$PBS_O_WORKDIR` cd "${PBS_O_WORKDIR}" || exit HERE=$(python3 -c 'import os; print(os.getcwd())') && export HERE +GIT_BRANCH=$(git branch --show-current) && export GIT_BRANCH + # 2. source `ALCF/helpers.sh` source "${HERE}/ALCF/helpers.sh" || exit # 3. call `setup` from `./ALCF/helpers.sh` setup "$@" || exit -export run_cmd="${run_cmd}" -echo "${run_cmd}" | tee -a "${OUTPUT_LOG}" +# export run_cmd="${run_cmd}" +echo "${run_cmd[@]}" | tee -a "${OUTPUT_LOG}" # 4. Tell user where to find output printf "[!! %s] View output at:\n %s\n" "$(printBlue "NOTE")" "$(printYellow "${OUTPUT_LOG}")" | tee -a "${OUTPUT_LOG}" -# 5. Ignore the following strings on Intel XPU devices -# (otherwise they'll clutter up logs) -XPU_IGNORE_STRING="CCL_WARN|\ -\ INFO\ \-\ |real_accelerator\.py|numexpr\.utils|async_io|libaio" +# # 5. 
Ignore the following strings on Intel XPU devices +# # (otherwise they'll clutter up logs) +# XPU_IGNORE_STRING="CCL_WARN|\ -\ INFO\ \-\ |real_accelerator\.py|numexpr\.utils|async_io|libaio" # if [[ $(ezpz_get_machine_name) == "aurora" ]]; then # module unload mpich && module load mpich # fi # # 6. Evaluate ${run_cmd} and append outputs to ${OUTPUT_LOG} -eval "${run_cmd}" |& grep -E -v "${XPU_IGNORE_STRING}" |& tee -a "${OUTPUT_LOG}" +# eval "${run_cmd[@]}" |& tee -a "${OUTPUT_LOG}" +eval "${run_cmd[*]}" |& tee -a "${OUTPUT_LOG}"