[BUG] ValueError: Tensors must be contiguous when using deepspeed.initialize #6571

Open
shadow150519 opened this issue Sep 25, 2024 · 2 comments
Labels
bug (Something isn't working) · training

Comments

@shadow150519

Describe the bug
I'm running the BingBertSquad example in DeepSpeedExamples/training/BingBertSquad with the google-bert/bert-large-uncased model from Hugging Face. I fine-tune the model with bash run_squad_deepspeed.sh 4 ckpt/bert_large_uncased/pytorch_model.bin /dataset /output. In order to load the HF checkpoint, I slightly changed run_squad_deepspeed.sh:

#!/bin/bash

#1: number of GPUs
#2: Model File Address
#3: BertSquad Data Directory Address
#4: Output Directory Address

NGPU_PER_NODE=$1
MODEL_FILE=$2
SQUAD_DIR=$3
OUTPUT_DIR=$4
LR=${5:-0.00003}
SEED=${6:-12345}
MASTER_PORT=${7:-29500}
DROPOUT=${8:-0.1}
echo "lr is ${LR}"
echo "seed is $SEED"
echo "master port is $MASTER_PORT"
echo "dropout is ${DROPOUT}"

# Force deepspeed to run with only local node
NUM_NODES=1
HOSTFILE=/dev/null

NGPU=$((NGPU_PER_NODE*NUM_NODES))
EFFECTIVE_BATCH_SIZE=24
MAX_GPU_BATCH_SIZE=3
PER_GPU_BATCH_SIZE=$((EFFECTIVE_BATCH_SIZE/NGPU))
if [[ $PER_GPU_BATCH_SIZE -lt $MAX_GPU_BATCH_SIZE ]]; then
       GRAD_ACCUM_STEPS=1
else
       GRAD_ACCUM_STEPS=$((PER_GPU_BATCH_SIZE/MAX_GPU_BATCH_SIZE))
fi
JOB_NAME="deepspeed_${NGPU}GPUs_${EFFECTIVE_BATCH_SIZE}batch_size"
config_json=deepspeed_bsz24_config.json
# original cmd for DS ckpt
# run_cmd="deepspeed --num_nodes ${NUM_NODES} --num_gpus ${NGPU_PER_NODE} \
#        --master_port=${MASTER_PORT} \
#        --hostfile ${HOSTFILE} \
#        nvidia_run_squad_deepspeed.py \
#        --bert_model bert-large-uncased \
#        --do_train \
#        --do_lower_case \
#        --predict_batch_size 3 \
#        --do_predict \
#        --train_file $SQUAD_DIR/train-v1.1.json \
#        --predict_file $SQUAD_DIR/dev-v1.1.json \
#        --train_batch_size $PER_GPU_BATCH_SIZE \
#        --learning_rate ${LR} \
#        --num_train_epochs 2.0 \
#        --max_seq_length 384 \
#        --doc_stride 128 \
#        --output_dir $OUTPUT_DIR \
#        --job_name ${JOB_NAME} \
#        --gradient_accumulation_steps ${GRAD_ACCUM_STEPS} \
#        --fp16 \
#        --deepspeed \
#        --deepspeed_config ${config_json} \
#        --dropout ${DROPOUT} \
#        --model_file $MODEL_FILE \
#        --seed ${SEED} \
#        --preln \
#        "

run_cmd="deepspeed --num_nodes ${NUM_NODES} --num_gpus ${NGPU_PER_NODE} \
       --master_port=${MASTER_PORT} \
       --hostfile ${HOSTFILE} \
       nvidia_run_squad_deepspeed.py \
       --bert_model bert-large-uncased \
       --do_train \
       --do_lower_case \
       --predict_batch_size 3 \
       --do_predict \
       --train_file $SQUAD_DIR/train-v1.1.json \
       --predict_file $SQUAD_DIR/dev-v1.1.json \
       --train_batch_size $PER_GPU_BATCH_SIZE \
       --learning_rate ${LR} \
       --num_train_epochs 2.0 \
       --max_seq_length 384 \
       --doc_stride 128 \
       --output_dir $OUTPUT_DIR \
       --job_name ${JOB_NAME} \
       --gradient_accumulation_steps ${GRAD_ACCUM_STEPS} \
       --fp16 \
       --deepspeed \
       --deepspeed_config ${config_json} \
       --dropout ${DROPOUT} \
       --model_file $MODEL_FILE \
       --seed ${SEED} \
       --ckpt_type HF \
       --origin_bert_config_file ckpt/bert_large_uncased/config.json \
       "
       # use the HF checkpoint and its original config file

echo ${run_cmd}
eval ${run_cmd}

However, I end up with the following error (a short illustration of the underlying contiguity problem follows the traceback):

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/wtx/workspace/python_project/DeepSpeedExamples/training/BingBertSquad/nvidia_run_squad_deepspeed.py", line 1165, in <module>
[rank0]:     main()
[rank0]:   File "/home/wtx/workspace/python_project/DeepSpeedExamples/training/BingBertSquad/nvidia_run_squad_deepspeed.py", line 878, in main
[rank0]:     model, optimizer, _, _ = deepspeed.initialize(
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 269, in __init__
[rank0]:     self._configure_distributed_model(model)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1201, in _configure_distributed_model
[rank0]:     self._broadcast_model()
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1120, in _broadcast_model
[rank0]:     dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/comm/comm.py", line 224, in broadcast
[rank0]:     return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 200, in broadcast
[rank0]:     return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 79, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2209, in broadcast
[rank0]:     work = group.broadcast([tensor], opts)
[rank0]: ValueError: Tensors must be contiguous
[2024-09-25 17:46:35,221] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 4166482
[2024-09-25 17:46:35,221] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 4166483
[2024-09-25 17:46:35,327] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 4166484
[2024-09-25 17:46:35,593] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 4166485
[2024-09-25 17:46:35,606] [ERROR] [launch.py:325:sigkill_handler] ['/home/wtx/miniconda3/envs/llm/bin/python', '-u', 'nvidia_run_squad_deepspeed.py', '--local_rank=3', '--bert_model', 'bert-large-uncased', '--do_train', '--do_lower_case', '--predict_batch_size', '3', '--do_predict', '--train_file', './dataset//train-v1.1.json', '--predict_file', './dataset//dev-v1.1.json', '--train_batch_size', '6', '--learning_rate', '0.00003', '--num_train_epochs', '2.0', '--max_seq_length', '384', '--doc_stride', '128', '--output_dir', './output/', '--job_name', 'deepspeed_4GPUs_24batch_size', '--gradient_accumulation_steps', '2', '--fp16', '--deepspeed', '--deepspeed_config', 'deepspeed_bsz24_config.json', '--dropout', '0.1', '--model_file', './weights/bert_large_uncased/pytorch_model.bin', '--seed', '12345', '--ckpt_type', 'HF', '--origin_bert_config_file', 'weights/bert_large_uncased/config.json'] exits with return code = 1
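
The broadcast in _broadcast_model() is rejecting a parameter tensor that is not contiguous. Below is a minimal sketch of how a parameter can end up in that state, under my assumption (not code taken from nvidia_run_squad_deepspeed.py) that the HF checkpoint loading assigns a transposed or otherwise strided view to a weight:

import torch

# Illustration only: assigning a transposed view to a parameter leaves it
# non-contiguous, which is exactly what torch.distributed.broadcast rejects.
linear = torch.nn.Linear(4, 8)
hf_weight = torch.randn(4, 8)               # stand-in for a weight read from the HF checkpoint
linear.weight.data = hf_weight.t()          # transposed view -> strided, non-contiguous storage
print(linear.weight.data.is_contiguous())   # False

# .contiguous() copies the data into a dense layout and would avoid the error.
linear.weight.data = linear.weight.data.contiguous()
print(linear.weight.data.is_contiguous())   # True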

To Reproduce
Steps to reproduce the behavior:
1. Download the HF bert-large-uncased model and put it into the ckpt/bert_large_uncased folder.
2. Run bash run_squad_deepspeed.sh 4 ckpt/bert_large_uncased/pytorch_model.bin /dataset /output.

Expected behavior
Training starts and fine-tuning completes without the ValueError.

ds_report output
Please run ds_report to give us details about your setup.

(llm) (base) wtx@gpu04:~/workspace/python_project/DeepSpeedExamples/training/BingBertSquad$ ds_report 
[2024-09-25 17:41:07,170] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/torch']
torch version .................... 2.4.0+cu121
deepspeed install path ........... ['/home/wtx/miniconda3/envs/llm/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.15.1, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 1.13, cuda 11.7
shared memory (/dev/shm) size .... 503.68 GB

System info (please complete the following information):

  • OS: Ubuntu 20.04
  • GPU count and types: one machine with 4 A40
  • Python version: 3.10.4
  • Any other relevant info about your setup
  • pytorch: 2.4.0
  • deepspeed: 0.15.1
shadow150519 added the bug (Something isn't working) and training labels on Sep 25, 2024
@shadow150519
Author

I found similar issues: #2736 in DeepSpeed and #94907 in PyTorch. A sketch of a possible workaround is below.
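
One untested workaround to try (it assumes access to the model object right before the deepspeed.initialize(...) call in nvidia_run_squad_deepspeed.py) is to force every parameter to be contiguous before handing the model to DeepSpeed:

# Untested workaround sketch: run this immediately before the existing
# deepspeed.initialize(...) call so that _broadcast_model() never sees a
# strided (non-contiguous) parameter tensor.
for p in model.parameters():
    if not p.data.is_contiguous():
        p.data = p.data.contiguous()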

@championsnet

championsnet commented Oct 2, 2024

I faced a similar issue when I upgraded the transformers package from 4.28.0 to 4.44.2 (for some reason). Once I reverted it and kept torch==2.2.1, lightning==2.2.1, accelerate==0.27.2 and deepspeed==0.14.0, the error went away. As to why it works like that, it is anyone's guess.

Note that I am using Python 3.11.
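
For anyone who wants to reproduce this pinned environment, here is a requirements-style snippet based only on the versions listed above (the transformers pin is the pre-upgrade version I mentioned; treat it as a suggestion, not a verified fix):

# pinned versions that avoided the error in my setup
torch==2.2.1
lightning==2.2.1
accelerate==0.27.2
deepspeed==0.14.0
transformers==4.28.0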
