
Qwen2VL 2B & 7B OOM #1390

Open
schopra8 opened this issue Dec 6, 2024 · 2 comments

Comments

@schopra8

schopra8 commented Dec 6, 2024

When fine-tuning a Qwen2 model on an A100 (80GB), I get OOMs.

This is surprising given batch size of 1, small images (256 x 256), and 4-bit training. With the same data, it's possible to train LLAMA3 11B with batch size of 8 and only 15 GB of memory consumed.
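
For context, a rough back-of-the-envelope sketch (assuming Qwen2-VL's documented 14x14 vision patches and 2x2 spatial merge; not measured from this run) puts a 256 x 256 image at only a few hundred vision patches, which is why the huge attention allocation in the traceback below is so surprising:

# Rough sketch (not from the repro itself): estimate the image token count for a
# 256 x 256 input, assuming Qwen2-VL's documented 14x14 vision patches and
# 2x2 spatial merge (patch_size=14, merge_size=2 in the model config).
patch_size, merge_size = 14, 2
grid = 256 // (patch_size * merge_size) * merge_size  # patch grid side, rounded down to a merge multiple -> 18
num_patches = grid * grid                              # 324 patches seen by the vision attention
num_image_tokens = num_patches // (merge_size ** 2)    # ~81 tokens after 2x2 merging
print(num_patches, num_image_tokens)                   # 324 81

Training script: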

from typing import Any, List
from unsloth import FastVisionModel
import pickle
import pandas as pd
from transformers import TrainerCallback
from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
import os
import argparse

LLAMA3_11B = 'llama3_11b'
QWEN2_2B = 'qwen2_2b'
QWEN2_7B = 'qwen2_7b'
MODEL_NAME_TO_MODEL_PATH = {LLAMA3_11B: "unsloth/Llama-3.2-11B-Vision-Instruct", QWEN2_2B: 'unsloth/Qwen2-VL-2B-Instruct-bnb-4bit', QWEN2_7B:'unsloth/Qwen2-VL-7B-Instruct-bnb-4bit'}

def train(model: str, training_dataset: List[Any],
          validation_dataset: List[Any]):
    # Load the model and tokenizer
    model_path = MODEL_NAME_TO_MODEL_PATH[model]
    print(f'Training: {model_path}')
    model, tokenizer = FastVisionModel.from_pretrained(
        model_path,
        load_in_4bit=True,
        use_gradient_checkpointing="unsloth",
    )
    model = FastVisionModel.get_peft_model(
        model,
        finetune_vision_layers=True,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=16,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    FastVisionModel.for_training(model)  # Enable for training!

    # Initialize the callback
    loss_logger = LossLoggerCallback()

    # Initialize the trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=UnslothVisionDataCollator(model, tokenizer),
        train_dataset=training_dataset,
        eval_dataset=validation_dataset,
        callbacks=[loss_logger],  # Add the custom callback
        args=SFTConfig(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            warmup_steps=5,
            num_train_epochs=6.3,
            learning_rate=2e-4,
            fp16=not is_bf16_supported(),
            bf16=is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir="outputs",
            report_to="none",
            remove_unused_columns=False,
            dataset_text_field="",
            dataset_kwargs={"skip_prepare_dataset": True},
            dataset_num_proc=8,
            max_seq_length=2048,
            disable_tqdm=False,
            dataloader_pin_memory=True,
            dataloader_persistent_workers=True,
            dataloader_prefetch_factor=2,
            dataloader_num_workers=8,
            packing=False),
    )

    # Train the model
    trainer.train()

    # Save the model and tokenizer
    model.save_pretrained("lora_model")
    tokenizer.save_pretrained("lora_model")
    model.save_pretrained_merged("unsloth_finetune", tokenizer)
    

class LossLoggerCallback(TrainerCallback):
    """
    Callback to log and store training loss at each step.
    """
    def __init__(self):
        self.training_losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if "loss" in logs:
            self.training_losses.append({"step": state.global_step, "loss": logs["loss"]})

if __name__ == "__main__":
    # Argument parser for configurable prefix
    parser = argparse.ArgumentParser(description="Add prefix to image paths in the dataset.")
    parser.add_argument(
        "--model",
        choices=[LLAMA3_11B, QWEN2_2B, QWEN2_7B],
        default=LLAMA3_11B,
        help=f"Choose between '{LLAMA3_11B}' (default), '{QWEN2_2B}', and '{QWEN2_7B}."
    )
    args = parser.parse_args()

    # Load the pickle files
    with open('unsloth_train_fine_tuning.pkl', "rb") as p:
        training_dataset = pickle.load(p)
    with open('unsloth_validation_fine_tuning.pkl', "rb") as p:
        validation_dataset = pickle.load(p)

    # Train the model
    train(args.model, training_dataset, validation_dataset)

Traceback:

`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 25,597 | Num Epochs = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient Accumulation steps = 1
\        /    Total batch size = 8 | Total steps = 20,160
 "-____-"     Number of trainable parameters = 28,950,528
🦥 Unsloth needs about 1-3 minutes to load everything - please wait!
  0%|                                                                                                                                                                                                                                   | 1/20160 [00:37<207:26:18, 37.04s/it]
Traceback (most recent call last):
  File "/home/sahil/my_folder/unsloth_fine_tuning_vllm.py", line 173, in <module>
    train(args.model, training_dataset, validation_dataset)
  File "/home/sahil/my_folder/unsloth_fine_tuning_vllm.py", line 97, in train
    trainer.train()
  File "<string>", line 157, in train
  File "<string>", line 381, in _fast_inner_training_loop
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3653, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3709, in compute_loss
    outputs = model(**inputs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/peft/peft_model.py", line 1644, in forward
    return self.base_model(
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sahil/my_folder/unsloth_compiled_cache/unsloth_compiled_module_qwen2_vl.py", line 1204, in forward
    return Qwen2VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, **loss_kwargs)
  File "/home/sahil/my_folder/unsloth_compiled_cache/unsloth_compiled_module_qwen2_vl.py", line 903, in Qwen2VLForConditionalGeneration_forward
    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1039, in forward
    hidden_states = self._gradient_checkpointing_func(
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/unsloth_zoo/gradient_checkpointing.py", line 360, in forward
    outputs = run_function(*args)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 431, in forward
    hidden_states = hidden_states + self.attn(
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/image-data-processing-twpiHuhd-py3.10/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sahil/my_folder/unsloth_compiled_cache/unsloth_compiled_module_qwen2_vl.py", line 409, in forward
    return VisionSdpaAttention_forward(self, hidden_states, cu_seqlens, rotary_pos_emb)
  File "/home/sahil/my_folder/unsloth_compiled_cache/unsloth_compiled_module_qwen2_vl.py", line 393, in VisionSdpaAttention_forward
    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 272.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 60.47 GiB is free. Including non-PyTorch memory, this process has 18.66 GiB memory in use. Of the allocated memory 18.02 GiB is allocated by PyTorch, and 138.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
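
For what it's worth, the failing allocation happens inside the vision tower's F.scaled_dot_product_attention, so the patch sequence the attention sees appears to be far larger than 256 x 256 images should produce. A minimal, untested sketch of one way to bound the per-image patch count, assuming the min_pixels / max_pixels options documented for the Qwen2-VL processor in transformers (repo id and values below are illustrative, not taken from this repro):

from transformers import AutoProcessor

# Hypothetical mitigation sketch: cap the pixel budget so the vision tower sees
# fewer patches per image. min_pixels / max_pixels follow the Qwen2-VL model
# card; the exact values are placeholders, not tuned settings.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=256 * 28 * 28,   # floor on resized image area
    max_pixels=512 * 28 * 28,   # ceiling; fewer pixels => fewer patches => less attention memory
)
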
@darkacorn

darkacorn commented Dec 6, 2024

Can someone reproduce this before we investigate further? edd maybe?

@InspectorCaracal

InspectorCaracal commented Dec 11, 2024

I'm also getting an extremely unexpected OOM error when trying to train with Qwen. I'm using a free Colab T4 for continued pretraining, which I've done with Mistral 7B with no problems in the past.

But trying to do the same with Qwen 2.5 3B runs out of memory.

Memory use before I run trainer.train():

GPU = Tesla T4. Max memory = 14.748 GB.
6.383 GB of memory reserved.

Result of the trainer.train() step:

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 259 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 8
\        /    Total batch size = 16 | Total steps = 16
 "-____-"     Number of trainable parameters = 861,798,400

Unsloth: Setting lr = 5.00e-06 instead of 5.00e-05 for embed_tokens.
Unsloth: Setting lr = 5.00e-06 instead of 5.00e-05 for lm_head.

[ 2/16, Epoch 0.06/1]
Step 	Training Loss

---------------------------------------------------------------------------

OutOfMemoryError                          Traceback (most recent call last)

<ipython-input-9-3d62c575fcfd> in <cell line: 1>()
----> 1 trainer_stats = trainer.train()

13 frames

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in _convert_to_fp32(tensor)
    780 
    781     def _convert_to_fp32(tensor):
--> 782         return tensor.float()
    783 
    784     def _is_fp16_bf16_tensor(tensor):

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.38 GiB. GPU 0 has a total capacity of 14.75 GiB of which 1.77 GiB is free. Process 282100 has 12.98 GiB memory in use. Of the allocated memory 12.53 GiB is allocated by PyTorch, and 301.43 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

This was with a max sequence length of only 2048 - when I initially tried it with a higher value, it gave me OOM on the exact same line of code, but without getting far enough to output the Step Training loss header before doing so.
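
Since the error message itself suggests it, here's a minimal sketch of the allocator setting it recommends; it has to be set before the first CUDA allocation (so before importing torch / unsloth in a fresh runtime), and I haven't verified whether it actually helps in this case:

import os

# Allocator hint taken from the OOM message; must be set before torch touches
# CUDA for the first time (fresh runtime / before any torch import).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"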

P.S. I also want to add that I was using triton==2.3.1 for this, as UnslothTrainer was giving me this error. Just in case that's relevant here.
