WMT slower in Pytorch than Jax #467

Closed
priyakasimbeg opened this issue Aug 8, 2023 · 19 comments · Fixed by #597
Labels
⏰ Timing gap Significant difference (>= 10%) between pytorch and jax workloads

Comments

@priyakasimbeg
Contributor

WMT pytorch is currently slower than Jax.

This bug is intended to at least document possible causes.

@runame @pomonam could you please summarize current findings and possible solutions?

@priyakasimbeg priyakasimbeg changed the title WMT slower Jax vs Pytorch WMT slower in Pytorch than Jax Aug 8, 2023
@priyakasimbeg priyakasimbeg added the 🚀 Launch Blocker Issues that are blocking launch of benchmark label Aug 9, 2023
@runame
Contributor

runame commented Aug 10, 2023

The main cause of the speed difference seems to be the update step, which is ~20% faster in Jax. The data loading is also faster in Jax, but the difference is insignificant in absolute terms. Juhan and I suspect that the model code is responsible for the slowdown. I will try to 1) use some of the new optimized transformer functions from PyTorch 2 and 2) rewrite the masking to be compatible with torch.compile.
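
To make the two ideas above concrete, here is a minimal sketch (not the workload's actual model code) of attention built on PyTorch 2's torch.nn.functional.scaled_dot_product_attention with a boolean padding mask. The helper names, the pad id of 0, and the tensor shapes are assumptions made for illustration only.

```python
try:
    import torch
    import torch.nn.functional as F
except ImportError:  # lets the sketch be imported without PyTorch installed
    torch = None


def make_padding_mask(tokens, pad_id=0):
    """Return a boolean mask of shape (batch, 1, 1, seq_len).

    True marks real tokens that may be attended to; False marks padding.
    pad_id=0 is an assumption for this example.
    """
    return (tokens != pad_id)[:, None, None, :]


def masked_attention(q, k, v, tokens, pad_id=0):
    """Attention over (batch, heads, seq_len, head_dim) tensors, ignoring padding keys."""
    mask = make_padding_mask(tokens, pad_id)
    # With a boolean attn_mask, True means "this position takes part in attention".
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


if torch is not None and hasattr(F, "scaled_dot_product_attention"):
    tokens = torch.tensor([[5, 7, 0, 0], [3, 4, 9, 0]])
    q = k = v = torch.randn(2, 2, 4, 8)
    print(masked_attention(q, k, v, tokens).shape)
```

The mask is built with plain tensor comparisons and indexing, without Python-side branching on tensor values, which is the kind of masking code that tends to trace cleanly under torch.compile.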

@msaroufim
Member

msaroufim commented Aug 11, 2023

FWIW I did try using a HF transformer implementation with boolean masks and that torch.compile'd just fine https://gist.github.com/msaroufim/946d7d26e89bab0bfe83d9929b533701

@msaroufim
Member

msaroufim commented Aug 12, 2023

@runame I managed to torch.compile() the wmt model just now on an A10G using PyTorch nightlies, after removing the disable from submission_runner.py

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 --force-reinstall

I don't have perf numbers yet since it feels stuck at "Translating evaluation dataset" (this probably needs a progress bar, since it's been running on my end for a few hours now). Do you mind trying it out on your end as well and lmk if you hit any new errors?

(sam) ubuntu@ip-172-31-9-217:~/algorithmic-efficiency$ python submission_runner.py     --framework=pytorch     --workload=wmt     --experiment_dir=$HOME/experiments     --experiment_name="experiment_$(date +"%Y%m%d_%H%M%S")"     --submission_path=reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py     --tuning_search_space=reference_algorithms/development_algorithms/wmt/tuning_search_space.json
2023-08-12 01:52:23.302616: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
I0812 01:52:28.113030 140124175877952 logger_utils.py:76] Creating experiment directory at /home/ubuntu/experiments/experiment_20230812_015222/wmt_pytorch.
W0812 01:52:28.115083 140124175877952 xla_bridge.py:636] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0812 01:52:28.116768 140124175877952 submission_runner.py:478] Using RNG seed 710879992
I0812 01:52:28.117297 140124175877952 submission_runner.py:487] --- Tuning run 1/1 ---
I0812 01:52:28.117359 140124175877952 submission_runner.py:492] Creating tuning directory at /home/ubuntu/experiments/experiment_20230812_015222/wmt_pytorch/trial_1.
I0812 01:52:28.117454 140124175877952 logger_utils.py:92] Saving hparams to /home/ubuntu/experiments/experiment_20230812_015222/wmt_pytorch/trial_1/hparams.json.
I0812 01:52:28.117842 140124175877952 submission_runner.py:176] Initializing dataset.
I0812 01:52:28.117906 140124175877952 submission_runner.py:183] Initializing model.
W0812 01:52:29.207352 140124175877952 submission_runner.py:198] These workloads cannot be fully compiled under current PyTorch version. Proceeding without `torch.compile`.
W0812 01:52:29.207496 140124175877952 submission_runner.py:202] Compiling model with `torch.compile`.
I0812 01:52:30.740139 140124175877952 submission_runner.py:205] Initializing optimizer.
I0812 01:52:30.741168 140124175877952 submission_runner.py:212] Initializing metrics bundle.
I0812 01:52:30.741234 140124175877952 submission_runner.py:230] Initializing checkpoint and logger.
I0812 01:52:30.741771 140124175877952 logger_utils.py:257] Unable to record workload.train_mean information. Continuing without it.
I0812 01:52:30.741840 140124175877952 logger_utils.py:257] Unable to record workload.train_stddev information. Continuing without it.
I0812 01:52:30.797685 140124175877952 submission_runner.py:251] Saving meta data to /home/ubuntu/experiments/experiment_20230812_015222/wmt_pytorch/trial_1/meta_data_0.json.
I0812 01:52:30.798029 140124175877952 submission_runner.py:254] Saving flags to /home/ubuntu/experiments/experiment_20230812_015222/wmt_pytorch/trial_1/flags_0.json.
I0812 01:52:30.817061 140124175877952 submission_runner.py:264] Starting training loop.
I0812 01:52:30.824819 140124175877952 dataset_info.py:578] Load dataset info from /home/ubuntu/data/wmt17_translate/de-en/1.0.0
I0812 01:52:30.863615 140124175877952 logging_logger.py:49] Constructing tf.data.Dataset wmt17_translate for split train, from /home/ubuntu/data/wmt17_translate/de-en/1.0.0
/opt/conda/envs/sam/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
I0812 01:52:55.268288 140124175877952 spec.py:320] Evaluating on the training split.
I0812 01:52:55.270218 140124175877952 dataset_info.py:578] Load dataset info from /home/ubuntu/data/wmt17_translate/de-en/1.0.0
I0812 01:52:55.295589 140124175877952 logging_logger.py:49] Constructing tf.data.Dataset wmt17_translate for split train, from /home/ubuntu/data/wmt17_translate/de-en/1.0.0
/opt/conda/envs/sam/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/opt/conda/envs/sam/lib/python3.10/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/opt/conda/envs/sam/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/opt/conda/envs/sam/lib/python3.10/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
I0812 01:53:26.027593 140124175877952 workload.py:131] Translating evaluation dataset.
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(133)translate_and_calculate_bleu()
-> references, predictions = [], []
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(134)translate_and_calculate_bleu()
-> for _ in range(num_batches):
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(135)translate_and_calculate_bleu()
-> pred_batch = next(ds_iter)
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(136)translate_and_calculate_bleu()
-> inputs = pred_batch['inputs']
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(137)translate_and_calculate_bleu()
-> targets = pred_batch['targets']
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(138)translate_and_calculate_bleu()
-> predicted = self.predict_step(inputs,
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(139)translate_and_calculate_bleu()
-> params,
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(140)translate_and_calculate_bleu()
-> decode.EOS_ID,
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(141)translate_and_calculate_bleu()
-> max_predict_length)
(Pdb) predicted
*** NameError: name 'predicted' is not defined
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(138)translate_and_calculate_bleu()
-> predicted = self.predict_step(inputs,
(Pdb) n
n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(144)translate_and_calculate_bleu()
-> weights = pred_batch.get('weights')
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(145)translate_and_calculate_bleu()
-> if weights is not None:
(Pdb) n
> /home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py(146)translate_and_calculate_bleu()
-> actual_batch_size = weights.sum(0)[0].item()
(Pdb) predicted
tensor([[ 8374,  8374,  8374,  ..., 16011, 16011, 16011],
        [24599, 24599, 24599,  ..., 24599, 24599, 24599],
        [24599, 24599, 24599,  ..., 24599, 24599, 24599],
        ...,
        [10620, 10620, 10620,  ...,  1361,  1361,  1361],
        [19729, 19729, 19729,  ...,  8328,  8328,  8328],
        [24599, 24599, 24599,  ...,  1847,  1847,  1847]], device='cuda:0',
       dtype=torch.int32)
(Pdb) 

@pomonam
Contributor

pomonam commented Aug 12, 2023

@msaroufim Thank you so much for looking into this! Yes, I will try this and see if there are any new issues.

@runame
Contributor

runame commented Aug 12, 2023

@msaroufim @pomonam Thanks a lot for investigating! I won't get to working on this before Sunday or Monday, will check this thread for updates then.

@pomonam
Contributor

pomonam commented Aug 14, 2023

@runame One thing I wanted to try out (but did not have time to do) was to use the PyTorch default functions (e.g., multi_head_attention_forward) when we don't need the cache, i.e. during training. I am not sure how useful this might be, but I think it is worth a try. Also, with torch.compile, the graph seems to break many times (and the timings don't change much), but I am not sure how to fix this.

@msaroufim
Member

msaroufim commented Aug 14, 2023

Oh interesting so you can confirm that it does compile but you're getting graph breaks? Are you running fullgraph=True?

@priyakasimbeg priyakasimbeg added the ⏰ Timing gap Significant difference (>= 10%) between pytorch and jax workloads label Aug 22, 2023
@runame
Contributor

runame commented Aug 23, 2023

@msaroufim I just tried running with fullgraph=True and got the following errors.

Branch dev (also with changes from #489) + nightly + fullgraph=True: I get the same error @pomonam has reported here.
With torch.compile moved before wrapping the model with DDP + as above:

Traceback (most recent call last):
  File "submission_runner.py", line 624, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 595, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 520, in score_submission_on_workload
    timing, metrics = train_once(workload, global_batch_size,
  File "submission_runner.py", line 299, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/baselines/adamw/pytorch/submission.py", line 74, in update_params
    logits_batch, new_model_state = workload.model_fn(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py", line 213, in model_fn
    logits_batch = model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 636, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 453, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2162, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 833, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/distributed.py", line 436, in compile_fn
    submod_compiler.run(*example_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/distributed.py", line 417, in run_node
    compiled_submod_real = self.compile_submod(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/distributed.py", line 361, in compile_submod
    self.compiler(input_mod, args),
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
TypeError: _convert_frame_assert() missing 2 required positional arguments: 'hooks' and 'frame_state'

Without fullgraph=True it seems to compile, the first update step is completed and evaluation starts. However, with the changes from #489 and nightly, I get an error (see below) at some point during eval, specifically during translation. With stable OR without the changes from #489 this error does not occur.

Traceback (most recent call last):
  File "submission_runner.py", line 330, in train_once
    latest_eval_result = workload.eval_model(global_eval_batch_size,
  File "/algorithmic-efficiency/algorithmic_efficiency/spec.py", line 334, in eval_model
    validation_metrics = self._eval_model_on_split(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/workload.py", line 165, in _eval_model_on_split
    metrics = self.eval_step(params, eval_batch)
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py", line 307, in eval_step
    logits, _ = self.model_fn(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py", line 212, in model_fn
    logits_batch = model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 636, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 453, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2162, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 833, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/distributed.py", line 436, in compile_fn
    submod_compiler.run(*example_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/distributed.py", line 430, in run_node
    return curr_submod(*new_args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 678, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 284, in __call__
    raise e
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 274, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.1005", line 26, in forward
    l__self___encoder_encoder_layers_0_self_attn_in_proj = self.L__self___encoder_encoder_layers_0_self_attn_in_proj(l__self___encoder_encoder_layers_0_norm1);  l__self___encoder_encoder_layers_0_norm1 = None
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

@runame
Contributor

runame commented Aug 23, 2023

Branch dev (also with changes from #489, also with torch.compile before DDP) + stable + fullgraph=True:

Traceback (most recent call last):
  File "submission_runner.py", line 624, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 595, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 520, in score_submission_on_workload
    timing, metrics = train_once(workload, global_batch_size,
  File "submission_runner.py", line 299, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/baselines/adamw/pytorch/submission.py", line 74, in update_params
    logits_batch, new_model_state = workload.model_fn(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py", line 212, in model_fn
    logits_batch = model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py", line 178, in forward
    memory = self.encoder(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 334, in catch_errors
    return hijacked_callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 372, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 541, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised TypeError: _convert_frame_assert() missing 1 required positional argument: 'hooks'

@msaroufim
Member

msaroufim commented Aug 23, 2023

Well it is promising that torch.compile() with inductor works - if the failure is only with fullgraph=True try out torch.compile(m, backend="explain") which should help debug

If I could get access to your VM to run experiments quickly that might help as well @priyakasimbeg - can chat more in person tomorrow

EDIT: This is a legit error actually, FWIW i created an mlperf label on pytorch side to track all open issues

@runame
Contributor

runame commented Aug 24, 2023

@msaroufim Is the 'explain' backend supposed to be available with the current nightly (2.1.0.dev20230824+cu118)? I get an error and torch._dynamo.list_backends() only lists ['cudagraphs', 'inductor', 'onnxrt', 'tvm'].

@msaroufim
Member

msaroufim commented Aug 24, 2023

Hmm, it should also be usable via torch._dynamo.explain(), which should list all the reasons for graph breaks.

@runame
Contributor

runame commented Aug 29, 2023

@msaroufim After the refactor in #489 the only remaining graph breaks are of this type:

Reason: DDPOptimizer intentional graph-break (See Note [DDPOptimizer]).

After setting torch._dynamo.config.optimize_ddp = False there are no graph breaks left and I'm still getting this error.
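
For readers following along, a minimal sketch of the setup described here: turn off the DDPOptimizer graph splitting, then ask torch.compile for a single graph. The helper names are hypothetical, and whether to compile before or after wrapping the model in DDP is left open, since both orderings were tried in this thread.

```python
try:
    import torch
    import torch._dynamo
except ImportError:  # lets the sketch be imported without PyTorch 2 installed
    torch = None


def disable_ddp_optimizer():
    """Stop Dynamo from splitting the graph at DDP bucket boundaries."""
    torch._dynamo.config.optimize_ddp = False


def compile_single_graph(model):
    """Hypothetical helper: compile `model` and error out on any graph break."""
    disable_ddp_optimizer()
    # fullgraph=True makes torch.compile raise instead of silently
    # falling back to eager execution at a graph break.
    return torch.compile(model, fullgraph=True)


# Usage (not executed here):
#   model = compile_single_graph(model)
```

Compilation is lazy, so any fullgraph error only shows up on the first forward call, not when compile_single_graph returns.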

@msaroufim
Member

@runame I chatted with Will Constable about this and his point is the DDPOptimizer will always give you graph breaks so if you have it enabled you won't be able to do fullgraph=True either

Do either you or @pomonam have a smaller repro of the linked error here: #487 (comment)?

@runame
Contributor

runame commented Aug 31, 2023

@msaroufim Not sure if it's useful, but I have created a smaller repro using the same model here. It runs successfully on a single GPU and fails with DDP.

Update: I decided to not follow up on this because this issue is currently not blocking us.

@runame
Contributor

runame commented Sep 6, 2023

@msaroufim Since we don't know why we get basically no speed improvements from using torch.compile despite using a standard transformer architecture and having no graph breaks, I have two other questions:

  1. Could you take a look at the model code (from PR #489, "Refactor of MultiheadAttention module in PyTorch WMT workload") and check if there are any obvious inefficiencies? I think the only non-standard part is the creation of the masks, but I have already timed a run without them to exclude them as the bottleneck and they seem to make almost no difference to the runtime.
  2. Maybe we misattributed the slow-down to the forward and backward pass due to a bug in our timing code and it is actually caused by something else, e.g. the data loading. You can find the code of the profiler that we use to time the different parts of training here. Essentially, we use torch.cuda.synchronize() together with time.time() or time.monotonic() for timing.
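
The timing approach described in point 2 can be sketched roughly as follows. This is not the project's profiler; the timed helper, the section names, and the time.sleep calls standing in for real work are all assumptions for illustration.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

try:
    import torch
except ImportError:  # the sketch still runs (without GPU sync) if PyTorch is missing
    torch = None


def _sync():
    """Wait for queued CUDA kernels so wall-clock time reflects finished work."""
    if torch is not None and torch.cuda.is_available():
        torch.cuda.synchronize()


@contextmanager
def timed(name, records):
    """Record the wall-clock duration of the enclosed block under `name`."""
    _sync()  # do not charge earlier asynchronous work to this section
    start = time.monotonic()
    try:
        yield
    finally:
        _sync()  # make sure this section's kernels have actually finished
        records[name].append(time.monotonic() - start)


records = defaultdict(list)
for _ in range(3):
    with timed("data_loading", records):
        time.sleep(0.01)  # stand-in for fetching the next batch
    with timed("update_step", records):
        time.sleep(0.02)  # stand-in for forward, backward and optimizer step

for name, durations in records.items():
    mean = sum(durations) / len(durations)
    print(f"{name}: mean {mean:.4f}s over {len(durations)} calls")
```

Synchronizing both before and after each section matters because CUDA kernels are launched asynchronously; without the first call, GPU work queued by the previous section would be attributed to the next one.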

@wconstab

torch._dynamo.exc.BackendCompilerFailed: compile_fn raised TypeError: _convert_frame_assert() missing 1 required positional argument: 'hooks'

@runame @msaroufim Not sure if this is relevant anymore (if you're unblocked some other way) but this specific issue about 'hooks' from the DDPOptimizer should have been fixed on master by pytorch/pytorch#107834.

@priyakasimbeg priyakasimbeg removed the 🚀 Launch Blocker Issues that are blocking launch of benchmark label Sep 12, 2023
@priyakasimbeg
Contributor Author

@BoyuanFeng is working on this. Compiling the loss function in addition to the model seems to significantly speed up WMT.

@priyakasimbeg
Contributor Author

Resolved in #597 after torch.compiling the loss functions.
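
For context, a minimal sketch of what compiling a loss function alongside the model can look like. This is not the code from #597; the weighted cross-entropy below and the maybe_compile helper are assumptions chosen to resemble a typical padded-sequence loss.

```python
try:
    import torch
    import torch.nn.functional as F
except ImportError:  # lets the sketch be imported without PyTorch installed
    torch = None


def weighted_cross_entropy(logits, targets, weights):
    """Per-token cross entropy, averaged over tokens with non-zero weight.

    logits: (batch, seq_len, vocab); targets and weights: (batch, seq_len).
    """
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        targets.reshape(-1),
        reduction="none",
    )
    weights = weights.reshape(-1)
    return (per_token * weights).sum() / weights.sum()


def maybe_compile(fn):
    """Return torch.compile(fn) where supported, otherwise fn unchanged."""
    if torch is None or not hasattr(torch, "compile"):
        return fn
    try:
        return torch.compile(fn)
    except RuntimeError:
        # torch.compile is not supported on every platform or Python version.
        return fn


# Compilation is lazy: it happens on the first call to loss_fn.
loss_fn = maybe_compile(weighted_cross_entropy)

# Usage (not executed here):
#   loss = loss_fn(logits, targets, weights)
```

A loss like this is a chain of small elementwise and reduction ops over a large (tokens x vocab) tensor, which is the sort of code where a compiler can fuse kernels; how much that helps in a given workload has to be measured.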
