High memory usage when training ViT + proj + LLM-LoRA on qwen2-vl-7b-instruct #2479

Open
Lycus99 opened this issue Nov 19, 2024 · 1 comment

Comments


Lycus99 commented Nov 19, 2024

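My goal is to train the vision encoder, so I fine-tune ViT + proj and apply LoRA to the LLM. In swift/llm/tuner.py, after LoRA is added to the LLM, I added the following code to make the ViT and proj trainable: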
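    # Runs after the LLM has been wrapped with LoRA, so parameter names
    # carry the PEFT prefix 'base_model.model.'.
    if not args.freeze_vit:
        for n, p in model.named_parameters():
            # Unfreeze every parameter under the visual prefix.
            if n.startswith('base_model.model.visual'):
                p.requires_grad = True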
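
To check that the snippet above unfreezes only what is intended, a small helper can total the trainable parameters per group. This is a minimal sketch, not part of swift: `trainable_summary` is a hypothetical name, and it only assumes that each parameter exposes `requires_grad` and `numel()`, as PyTorch parameters do.

```python
from collections import defaultdict


def trainable_summary(named_parameters, visual_prefix="base_model.model.visual"):
    """Sum trainable parameter counts, split into 'visual' and 'other'.

    `named_parameters` is any iterable of (name, param) pairs, for example
    model.named_parameters(). Frozen parameters are skipped.
    """
    totals = defaultdict(int)
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        group = "visual" if name.startswith(visual_prefix) else "other"
        totals[group] += param.numel()
    return dict(totals)
```

Calling `trainable_summary(model.named_parameters())` right after the unfreezing loop should show the LoRA weights under `other` and the ViT + proj weights under `visual`; the two together should match the trainable parameter count that swift prints.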

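The command-line arguments are below. I turned gradient checkpointing off because another issue mentioned that the vision tower does not support gradient checkpointing well: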

MAX_PIXELS=602112 CUDA_VISIBLE_DEVICES=0,1,2,3 NPROC_PER_NODE=4 swift sft \
    --model_type qwen2-vl-7b-instruct \
    --model_id_or_path /data/liyc/Code/qwen2-vl-7b-instruct \
    --model_revision master \
    --sft_type lora \
    --tuner_backend peft \
    --template_type AUTO \
    --dtype AUTO \
    --output_dir /data/liyc/Output/qwen2-vl-7b/swift_lora_vit \
    --dataset /data/liyc/Dataset/swift_pmv_it_swinir_50k.json \
    --deepspeed default-zero3 \
    --train_dataset_sample -1 \
    --num_train_epochs 3 \
    --max_length 600 \
    --check_dataset_strategy warning \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0.05 \
    --gradient_checkpointing false \
    --weight_decay 0.1 \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --per_device_eval_batch_size 1 \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 500 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --use_flash_attn true \
    --deepspeed zero3-offload \
    --freeze_vit false
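
For a sense of how much work the vision tower does per image under `MAX_PIXELS=602112`, the arithmetic can be sketched as follows. It assumes the patch size (14 pixels) and the 2x2 patch merging described in the Qwen2-VL documentation; `visual_token_budget` is a hypothetical helper, and real images are resized to fit under the cap, so these are upper bounds.

```python
PATCH = 14  # assumed ViT patch edge in pixels (Qwen2-VL documentation)
MERGE = 2   # assumed 2x2 patch merging before tokens reach the LLM


def visual_token_budget(max_pixels):
    """Return (ViT patches, LLM visual tokens) for an image at the pixel cap."""
    patches = max_pixels // (PATCH * PATCH)
    llm_tokens = patches // (MERGE * MERGE)
    return patches, llm_tokens


patches, llm_tokens = visual_token_budget(602112)
print(patches, llm_tokens)  # 3072 768
```

With the ViT trainable and gradient checkpointing off, activations for up to 3072 patches per image have to be kept through every ViT block for the backward pass, which is one plausible contributor to the memory figures reported in this issue.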

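Training environment: torch=2.2.0, cuda=11.8, four H20 GPUs. There are 695.94M trainable parameters, and the GPU memory reported by swift grows from 32 GB at the start to 80.8 GB. Isn't full-parameter training with the ViT frozen only about 4 × 60 GB? Why is memory consumption so high here?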

[Image: WeChat screenshot 20241119210052]
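
A rough estimate of the memory taken by the trainable weights and their optimizer states can help narrow down where the 80.8 GB goes. The sketch below assumes mixed-precision training with Adam, where each trainable parameter commonly costs about 16 bytes (2 for the half-precision weight, 2 for its gradient, 4 for the fp32 master copy, and 4 each for the two Adam moments), and that ZeRO-3 shards these evenly across the GPUs. `model_state_gib` is a hypothetical helper; the frozen base weights add roughly 2 bytes per parameter on top, also sharded, and optimizer offload would move part of this to CPU memory.

```python
# Assumed per-parameter cost for mixed-precision Adam:
# half-precision weight + gradient, fp32 master copy, Adam m and v.
BYTES_PER_TRAINABLE_PARAM = 2 + 2 + 4 + 4 + 4


def model_state_gib(trainable_params, num_gpus):
    """Return (total GiB, per-GPU GiB) for trainable weights plus optimizer states."""
    total_bytes = trainable_params * BYTES_PER_TRAINABLE_PARAM
    return total_bytes / 2**30, total_bytes / num_gpus / 2**30


total_gib, per_gpu_gib = model_state_gib(695.94e6, 4)
print(f"{total_gib:.1f} GiB total, {per_gpu_gib:.1f} GiB per GPU")
# 10.4 GiB total, 2.6 GiB per GPU
```

Under these assumptions the model states for 695.94M trainable parameters are a small fraction of the observed usage, which suggests, without proving it, that activations held for the backward pass account for much of the rest.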


Lycus99 commented Nov 20, 2024

After setting --gradient_checkpointing true, memory usage did come down, but why does it keep increasing?
[Image: WeChat screenshot 20241120094817]
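
One way to investigate a steadily rising number is to record allocated and reserved memory separately at each step. The sketch below is a hypothetical helper, not a swift feature: it takes two reader callables so it can be used with `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`, both of which return bytes. Which counter swift's log line reports is not stated in this thread.

```python
def make_memory_logger(read_allocated, read_reserved):
    """Build a logger that records (step, allocated GiB, reserved GiB).

    Intended readers are torch.cuda.memory_allocated and
    torch.cuda.memory_reserved, which both return byte counts.
    """
    history = []

    def log(step):
        allocated = read_allocated() / 2**30
        reserved = read_reserved() / 2**30
        history.append((step, allocated, reserved))
        return allocated, reserved

    return log, history


# Usage inside a training loop (requires a CUDA build of PyTorch):
#   import torch
#   log, history = make_memory_logger(torch.cuda.memory_allocated,
#                                     torch.cuda.memory_reserved)
#   log(step)  # call once per optimizer step
```

If allocated memory stays flat while reserved memory climbs, the growth is more likely caching or fragmentation in the allocator than a leak; if allocated memory jumps on particular steps, larger images or longer sequences in those batches are a possible cause.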
