Training crashed early into run: CUDA error: invalid argument #1387

Open
selalipop opened this issue Dec 5, 2024 · 2 comments

Comments

@selalipop

Minimal script:

os.environ["WANDB_PROJECT"] = "project-name"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints


max_seq_length = 14500
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = validation_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_eval_batch_size = 1,
        eval_accumulation_steps = 4,
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        warmup_steps = 10,
        num_train_epochs = 3, # Set this for 1 full training run.
        learning_rate = 1e-5,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        save_steps = 120,
        eval_steps = 120,
        evaluation_strategy="steps",
        do_eval=True, 
        weight_decay = 0.1,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "output",
        report_to = "wandb",
        run_name = "run-name",
    ),
)

import gc
torch.cuda.empty_cache()
gc.collect()


trainer_stats = trainer.train()

Stacktrace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 4
      2 os.environ["WANDB_PROJECT"] = "project-mistral"  # name your W&B project
      3 os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
----> 4 trainer_stats = trainer.train()

File <string>:157, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:380, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:31, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File /usr/local/lib/python3.11/dist-packages/unsloth/models/_utils.py:1028, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
   1026     pass
   1027 pass
-> 1028 return self._old_compute_loss(model, inputs, *args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/_compile.py:32, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     29     disable_fn = torch._dynamo.disable(fn, recursive)
     30     fn.__dynamo_disable = disable_fn
---> 32 return disable_fn(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:632, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    630 prior = _maybe_set_eval_frame(callback)
    631 try:
--> 632     return fn(*args, **kwargs)
    633 finally:
    634     _maybe_set_eval_frame(prior)

File /usr/local/lib/python3.11/dist-packages/unsloth/models/llama.py:1084, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
   1069 @torch._disable_dynamo
   1070 def PeftModelForCausalLM_fast_forward(
   1071     self,
   (...)
   1082     **kwargs,
   1083 ):
-> 1084     return self.base_model(
   1085         input_ids=input_ids,
   1086         causal_mask=causal_mask,
   1087         attention_mask=attention_mask,
   1088         inputs_embeds=inputs_embeds,
   1089         labels=labels,
   1090         output_attentions=output_attentions,
   1091         output_hidden_states=output_hidden_states,
   1092         return_dict=return_dict,
   1093         num_logits_to_keep=num_logits_to_keep,
   1094         **kwargs,
   1095     )

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.11/dist-packages/unsloth/models/mistral.py:220, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
    212     outputs = LlamaModel_fast_forward_inference(
    213         self,
    214         input_ids,
   (...)
    217         attention_mask = attention_mask,
    218     )
    219 else:
--> 220     outputs = self.model(
    221         input_ids=input_ids,
    222         causal_mask=causal_mask,
    223         attention_mask=attention_mask,
    224         position_ids=position_ids,
    225         past_key_values=past_key_values,
    226         inputs_embeds=inputs_embeds,
    227         use_cache=use_cache,
    228         output_attentions=output_attentions,
    229         output_hidden_states=output_hidden_states,
    230         return_dict=return_dict,
    231     )
    232 pass
    234 hidden_states = outputs[0]

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /usr/local/lib/python3.11/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.11/dist-packages/unsloth/models/llama.py:791, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    788 pass
    790 if offloaded_gradient_checkpointing:
--> 791     hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
    792         decoder_layer,
    793         hidden_states,
    794         mask,
    795         attention_mask,
    796         position_ids,
    797         past_key_values,
    798         output_attentions,
    799         use_cache,
    800     )[0]
    802 elif gradient_checkpointing:
    803     def create_custom_forward(module):

File /usr/local/lib/python3.11/dist-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

File /usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py:465, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
    463 if cast_inputs is None:
    464     args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
--> 465     return fwd(*args, **kwargs)
    466 else:
    467     autocast_context = torch.is_autocast_enabled(device_type)

File /usr/local/lib/python3.11/dist-packages/unsloth_zoo/gradient_checkpointing.py:154, in Unsloth_Offloaded_Gradient_Checkpointer.forward(ctx, forward_function, hidden_states, *args)
    151 @staticmethod
    152 @torch_amp_custom_fwd
    153 def forward(ctx, forward_function, hidden_states, *args):
--> 154     saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
    155     with torch.no_grad():
    156         output = forward_function(hidden_states, *args)

RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Running on RTX 6000 Ada, VRAM usage was ~60%
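
If it helps reproduce this: the error message above suggests re-running with CUDA_LAUNCH_BLOCKING=1 so the reported stack trace points at the call that actually failed. A minimal sketch (this has to run before the first CUDA call in the process, e.g. at the top of the notebook):

import os

# Make CUDA kernel launches synchronous so the failing call shows up at the
# correct point in the stack trace (per the hint in the error message above).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"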

@micahr234

I got the same thing

@danielhanchen
Contributor

Apologies for the delay - this is most likely caused by running out of system RAM (not GPU VRAM), since Unsloth offloads activations to system RAM. I would reduce per_device_train_batch_size = 4 to something smaller.
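
For reference, a sketch of that change against the script above. The specific values are illustrative, and raising gradient_accumulation_steps is optional - it is only there to keep the effective batch size the same:

from transformers import TrainingArguments

# Smaller per-device batch so less activation memory is offloaded to system RAM
# at once; more accumulation steps keep the effective batch size (1 * 32 = 32)
# equal to the original 4 * 8 = 32. All other arguments stay as posted above.
args = TrainingArguments(
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 32,
    output_dir = "output",
)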
