Fix torch.compile #1

Open
3 tasks
lapp0 opened this issue Aug 24, 2024 · 1 comment

lapp0 commented Aug 24, 2024

Reproducer

import distily

distily.run.benchmark(
    teacher_model_name_or_path="gpt2",
    output_dir="distily_verify_compile",
    hub_model_id="distily/distily_verify_compile",
    push_to_hub=True,
    report_to="tensorboard",
    dataset_sample_size=4000,
    gradient_accumulation_steps=1,
    harness_benchmarks=[],
    params=[
        {"torch_compile": True},
        {"torch_compile": False},
    ]
)

Error

Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/opt/conda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1337, in torch_dynamo_resume_in_forward_at_1315
    lm_logits = self.lm_head(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 66, in benchmark
    res = train(*parsed_args_tuple)
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 86, in train
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 92, in train
    train_output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2205, in _inner_training_loop
    self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 135, in evaluate
    super().evaluate(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 4075, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 103, in compute_loss
    loss_dict = self.distillation_objective(self.teacher_model, model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 106, in __call__
    logits_loss = self._calc_loss(out_s.logits, out_t.logits, self.logits_loss_component, device)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 135, in _calc_loss
    loss = loss_component.get_loss(feat_s, feat_t)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/loss.py", line 47, in kl_divergence_loss
    teacher_prob = F.softmax(feat_t, dim=-1)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 1885, in softmax
    ret = input.softmax(dim)
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/opt/conda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1337, in torch_dynamo_resume_in_forward_at_1315
    lm_logits = self.lm_head(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

Implications

Completion of this issue allows us to benchmark and integrate

lapp0 commented Sep 9, 2024

mobiusml/hqq#108
