
Regression - Phi3 has graph breaks in 4.48 but not in 4.47.1 #35716

Open
kshitij12345 opened this issue Jan 15, 2025 · 8 comments

@kshitij12345
Contributor

System Info

  • transformers version: 4.48.0
  • Platform: Linux-6.8.0-48
  • Python version: 3.12.3
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.6.0
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: NVIDIA RTX 6000 Ada Generation

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Small 2-layer Phi-3 config so the repro stays fast.
cfg = AutoConfig.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
cfg.num_hidden_layers = 2
with torch.device("cuda"):
    m = AutoModelForCausalLM.from_config(cfg)

# Custom compile backend that prints once per captured subgraph,
# which makes graph breaks visible.
def backend(gm, sample_args):
    # gm.print_readable()
    print("SUBGRAPH")
    return gm

m.model = torch.compile(m.model, backend=backend)

input_ids = torch.randint(0, 100, (1, 4096), device="cuda")
m(input_ids)

With 4.48 we see 4 subgraphs, while with 4.47.1 we see only 1 subgraph.

Running with TORCH_LOGS="graph_breaks" prints:

V0115 16:09:58.933000 510381 torch/_dynamo/symbolic_convert.py:444] [1/0] [__graph_breaks] Graph break (details suppressed) in user code at /usr/local/lib/python3.12/dist-packages/transformers/models/phi3/modeling_phi3.py:386
V0115 16:09:58.933000 510381 torch/_dynamo/symbolic_convert.py:444] [1/0] [__graph_breaks] Reason: Unsupported: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
V0115 16:09:58.945000 510381 torch/_dynamo/symbolic_convert.py:444] [2/0] [__graph_breaks] Graph break (details suppressed) in user code at /usr/local/lib/python3.12/dist-packages/transformers/models/phi3/modeling_phi3.py:386
V0115 16:09:58.945000 510381 torch/_dynamo/symbolic_convert.py:444] [2/0] [__graph_breaks] Reason: Data-dependent jump

Expected behavior

Ideally there should be a single subgraph, as before.

@ArthurZucker
Collaborator

Will have a look asap!

@ArthurZucker
Collaborator

I cannot reproduce this!

[Screenshot of running the reproduction]

@kshitij12345
Contributor Author

You have actually reproduced it: SUBGRAPH is printed 4 times (previously there were no graph breaks, so it was printed only once, for the single subgraph).

Alternatively, you can put the repro code in a script test.py and run it with TORCH_LOGS="graph_breaks" python test.py; you will then see the graph-break logs from torch.compile.

@Rocketknight1
Member

Line 386 is this block:

if hasattr(self.config, "original_max_position_embeddings"):
    original_max_position_embeddings = self.config.original_max_position_embeddings
else:
    original_max_position_embeddings = self.config.max_position_embeddings

I would have expected this to be compilable, but possibly hasattr() or the attribute access on self.config causes a graph break.

@kshitij12345, if you're comfortable, can you try:

  1. Cloning transformers
  2. Installing that local clone in editable mode with pip install -e .
  3. Editing modeling_phi3.py to delete those 4 lines, instead setting self.original_max_embeddings as a module attribute in the module's __init__ and reading that attribute in _longrope_frequency_update() (a rough sketch follows below)
  4. Letting us know if that resolves the graph breaks!
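
A minimal sketch of what step 3 could look like, assuming the 4.48 layout of Phi3RotaryEmbedding; the class body and attribute name here are illustrative, not the actual patch:

import torch
import torch.nn as nn

class Phi3RotaryEmbeddingSketch(nn.Module):
    """Illustrative only: resolve the config attribute once at init time so the
    compiled forward path never calls hasattr() or touches self.config."""

    def __init__(self, config, device=None):
        super().__init__()
        # Resolved once here instead of inside _longrope_frequency_update().
        self.original_max_position_embeddings = getattr(
            config, "original_max_position_embeddings", config.max_position_embeddings
        )

    def _longrope_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        # Plain attribute read, no hasattr()/config lookup in the traced region.
        if seq_len > self.original_max_position_embeddings:
            pass  # switch to the long-context inv_freq here (omitted)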

@zucchini-nlp
Member

zucchini-nlp commented Jan 16, 2025

Hey all! I've been looking at this issue for a while and identified where the graph breaks come from. The two culprits are the manual device movement here

self.original_inv_freq = self.original_inv_freq.to(device)

and the fact that we stopped inferring the sequence length from the input's shape and instead infer it from the position_ids values (below). Both points are related to our latest refactor, which moved the RoPE embedding calculation to the base class:

seq_len = torch.max(position_ids) + 1

cc'ing @Cyrilvallez as well for the RoPE refactoring
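
To make the second point concrete, here is a minimal standalone illustration (assuming torch 2.x dynamo behavior; this is not the transformers code itself): branching on a value computed from tensor data is data-dependent control flow and forces a graph break, whereas a shape-derived length is known at trace time.

import torch

@torch.compile(fullgraph=False)
def branch_on_values(position_ids, limit):
    # torch.max() reads tensor *values*, so this `if` is dynamic control flow
    # and dynamo has to break the graph (the "Data-dependent jump" in the logs above).
    seq_len = torch.max(position_ids) + 1
    if seq_len > limit:
        return position_ids * 2
    return position_ids

@torch.compile(fullgraph=False)
def branch_on_shape(hidden_states, limit):
    # A shape-derived length is known while tracing, so the same branch
    # compiles without a break.
    seq_len = hidden_states.shape[1]
    if seq_len > limit:
        return hidden_states * 2
    return hidden_states

Running branch_on_values under TORCH_LOGS="graph_breaks" should show the same kind of break as the Phi3 log above, while branch_on_shape should not.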

@Cyrilvallez
Member

Cyrilvallez commented Jan 16, 2025

Thanks for the ping and investigation @zucchini-nlp! For the first issue, making original_inv_freq a buffer of the model could be a workaround: the device would then be handled automatically when calling .to() on the model, so it would no longer be changed directly in the forward pass.
For the second point, inferring the seq_len from the cache as was done before should work, since it is a DynamicCache and only relies on .shape (it would still break with StaticCache, though).
I can open a PR if you did not already! 🤗
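
For the first point, a rough sketch of the buffer idea (illustrative only; names, shapes, and the forward body are assumptions, not the actual fix):

import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    """Illustrative only: keep original_inv_freq as a (non-persistent) buffer so
    model.to(device) moves it automatically and forward never calls .to()."""

    def __init__(self, inv_freq: torch.Tensor):
        super().__init__()
        self.register_buffer("inv_freq", inv_freq.clone(), persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def forward(self, x, position_ids):
        # No `self.original_inv_freq = self.original_inv_freq.to(device)` here,
        # which was one of the lines triggering a graph break under torch.compile.
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :].float()
        return freqs.cos(), freqs.sin()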

@Cyrilvallez
Member

Cyrilvallez commented Jan 16, 2025

The second point is related to what I faced in #35681 in order to never break the graph.

@zucchini-nlp
Member

@Cyrilvallez great, thanks for proposing solutions. Feel free to submit a PR :)
