Torchtune QAT + compile raises error #1427

Open
felipemello1 opened this issue Dec 17, 2024 · 3 comments
@felipemello1

Hi all, when fine-tuning with torchtune QAT (torchao API) + compile, I get an error and a warning:

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora compile=True

error:

Running with torchrun...
W1217 12:03:37.193000 859018 site-packages/torch/distributed/run.py:793] 
W1217 12:03:37.193000 859018 site-packages/torch/distributed/run.py:793] *****************************************
W1217 12:03:37.193000 859018 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1217 12:03:37.193000 859018 site-packages/torch/distributed/run.py:793] *****************************************
Running QATLoRAFinetuneRecipeDistributed with resolved config:
batch_size: 4
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files:
  - model.safetensors
  model_type: LLAMA3_2
  output_dir: /tmp/torchtune/llama3_2_1B/qat_lora
  recipe_checkpoint: null
compile: true
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
device: cuda
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/llama3_2_1B/qat_lora/logs
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  apply_lora_to_mlp: true
  lora_alpha: 128
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 64
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/torchtune/llama3_2_1B/qat_lora
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama3_2_1B/qat_lora/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
NCCL version 2.21.5+cuda12.4
Setting manual seed to local seed 685700170. Local seed is seed + rank = 685700170 + 0
Writing logs to /tmp/torchtune/llama3_2_1B/qat_lora/logs/log_1734465832.txt
FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
/data/users/felipemello/torchtune/torchtune/training/quantization.py:178: UserWarning: *QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead
  warn(
/data/users/felipemello/torchtune/torchtune/training/quantization.py:178: UserWarning: *QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead
  warn(
Compiling model layers with torch.compile...
Instantiating model and loading checkpoint took 1.74 secs
Memory stats after model init:
        GPU peak memory allocation: 1.24 GiB
        GPU peak memory reserved: 1.26 GiB
        GPU peak memory active: 1.24 GiB
Optimizer is initialized.
Compiling loss with torch.compile...
Loss is initialized.
Dataset and Sampler are initialized.
Learning rate scheduler is initialized.
 Profiling disabled.
 Profiler config after instantiation: {'enabled': False}
  0%|                                                                                                                                                          | 0/808 [00:00<?, ?it/s][rank1]: Traceback (most recent call last):
[rank1]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 971, in <module>
[rank1]:     sys.exit(recipe_main())
[rank1]:              ^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank1]:     sys.exit(recipe_main(conf))
[rank1]:              ^^^^^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 966, in recipe_main
[rank1]:     recipe.train()
[rank1]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 836, in train
[rank1]:     logits = self._model(**batch)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 635, in forward
[rank1]:     h = layer(
[rank1]:         ^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1734, in _wrapped_call_impl
[rank1]:     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 82, in forward
[rank1]:     def forward(
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/attention.py", line 181, in forward
[rank1]:     def forward(
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/attention.py", line 234, in torch_dynamo_resume_in_forward_at_234
[rank1]:     q = self.q_proj(x)
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/peft/lora.py", line 241, in forward
[rank1]:     def forward(self, x: torch.Tensor) -> torch.Tensor:
[rank1]:   File "/data/users/felipemello/torchtune/torchtune/modules/peft/lora.py", line 250, in torch_dynamo_resume_in_forward_at_250
[rank1]:     _x = self.activation_fake_quantizer(x)
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
[rank1]:     return compiled_fn(full_args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 308, in runtime_wrapper
[rank1]:     all_outs = call_func_at_runtime_with_args(
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:                             ^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 98, in g
[rank1]:     return f(*args)
[rank1]:            ^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1525, in forward
[rank1]:     fw_outs = call_func_at_runtime_with_args(
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:                             ^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
[rank1]:     return compiled_fn(runtime_args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
[rank1]:     outs = compiled_fn(args)
[rank1]:            ^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
[rank1]:     return self.current_callable(inputs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_inductor/utils.py", line 1977, in run
[rank1]:     return model(new_inputs)
[rank1]:            ^^^^^^^^^^^^^^^^^
[rank1]:   File "/tmp/torchinductor_felipemello/iw/ciwy5odtye2qhnp7wp3ucziyofvpy53da7pze22kpcgwqgsk6jgd.py", line 238, in call
[rank1]:     assert_size_stride(primals_2, (2048, 2048), (2048, 1))
[rank1]: AssertionError: expected size 512==2048, stride 2048==2048 at dim=0
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 971, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 966, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py", line 836, in train
[rank0]:     logits = self._model(**batch)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 635, in forward
[rank0]:     h = layer(
[rank0]:         ^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1734, in _wrapped_call_impl
[rank0]:     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 82, in forward
[rank0]:     def forward(
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/attention.py", line 181, in forward
[rank0]:     def forward(
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/attention.py", line 234, in torch_dynamo_resume_in_forward_at_234
[rank0]:     q = self.q_proj(x)
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/peft/lora.py", line 241, in forward
[rank0]:     def forward(self, x: torch.Tensor) -> torch.Tensor:
[rank0]:   File "/data/users/felipemello/torchtune/torchtune/modules/peft/lora.py", line 250, in torch_dynamo_resume_in_forward_at_250
[rank0]:     _x = self.activation_fake_quantizer(x)
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
[rank0]:     return compiled_fn(full_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 308, in runtime_wrapper
[rank0]:     all_outs = call_func_at_runtime_with_args(
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank0]:     out = normalize_as_list(f(args))
[rank0]:                             ^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 98, in g
[rank0]:     return f(*args)
[rank0]:            ^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1525, in forward
[rank0]:     fw_outs = call_func_at_runtime_with_args(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank0]:     out = normalize_as_list(f(args))
[rank0]:                             ^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
[rank0]:     return compiled_fn(runtime_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
[rank0]:     outs = compiled_fn(args)
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
[rank0]:     return self.current_callable(inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/_inductor/utils.py", line 1977, in run
[rank0]:     return model(new_inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/tmp/torchinductor_felipemello/dc/cdchmrs3lk3evxm3dwqo7om2bsaf4ss7klkndh2f6udf3hhk7dy2.py", line 238, in call
[rank0]:     assert_size_stride(primals_2, (2048, 2048), (2048, 1))
[rank0]: AssertionError: expected size 512==2048, stride 2048==2048 at dim=0
  0%|                                                                                                                                                          | 0/808 [00:06<?, ?it/s]
[rank0]:[W1217 12:04:04.677995780 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
W1217 12:04:05.712000 859018 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 869967 closing signal SIGTERM
E1217 12:04:05.877000 859018 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 869968) of binary: /home/felipemello/.conda/envs/vllm/bin/python
Traceback (most recent call last):
  File "/home/felipemello/.conda/envs/vllm/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/data/users/felipemello/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/data/users/felipemello/torchtune/torchtune/_cli/run.py", line 212, in _run_cmd
    self._run_distributed(args, is_builtin=is_builtin)
  File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/_cli/run.py", line 101, in _run_distributed
    run(args)
  File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/vllm/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/data/users/felipemello/torchtune/recipes/qat_lora_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-12-17_12:04:05
  host      : devgpu018.nha2.facebook.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 869968)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

warning:

/data/users/felipemello/torchtune/torchtune/training/quantization.py:178: UserWarning: *QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead

The warning comes from the torchtune side, but I am not sure whether it is accurate (in which case QAT should be updated) or whether the warning should simply be removed.
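
In case it helps, here is a quick way to check which concrete class the config actually resolves to, and whether a *ModuleSwap class is really in the picture. This is just a sketch that reuses the component and groupsize from the resolved config above, nothing more:

# Sketch only: instantiate the quantizer exactly as the resolved config does and
# inspect the class hierarchy to see whether a deprecated *ModuleSwap variant
# is actually involved.
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
print(type(quantizer))          # concrete class the config resolves to
print(type(quantizer).__mro__)  # does any *ModuleSwap class appear here?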

@andrewor14 do you mind taking a look?

cc: @ebsmothers

@felipemello1
Author

I don't remember exactly, but I think we had a similar issue in the past with NF4 and QLoRA, and the solution was to change the kernel size; I'm not 100% sure, though.
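
If it helps narrow things down, below is roughly the standalone repro I would try to see whether the group size alone changes anything. It assumes the quantizer exposes the usual torchao prepare() entry point and mirrors the (2048, 2048) q_proj shape from the failing assert; it does not include FSDP or LoRA, so it may well not reproduce the sharded-shape mismatch, it is only a sketch:

# Rough, untested sketch: fake-quantize a single linear with the same quantizer
# the recipe uses, compile it, and run it with a few group sizes.
import torch
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

for groupsize in (32, 128, 256):
    model = torch.nn.Sequential(torch.nn.Linear(2048, 2048, bias=False)).cuda().to(torch.bfloat16)
    model = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize).prepare(model)
    compiled = torch.compile(model)
    out = compiled(torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16))
    print(groupsize, out.shape)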

@andrewor14
Contributor

The warning is triggered too easily; I need to fix it. I'll look into the error.

andrewor14 self-assigned this Dec 17, 2024
@andrewor14
Contributor

Warning fixed in: pytorch/torchtune#2174
