PyTorch Conformer OOMs sometimes #497

Closed
priyakasimbeg opened this issue Aug 22, 2023 · 18 comments
Labels
🔥 PyTorch Issue that mainly deals with the PyTorch version of the code

Comments

@priyakasimbeg
Contributor

PyTorch Conformer occasionally OOMs.

Description

Traceback:

Traceback (most recent call last):
  File "submission_runner.py", line 624, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 595, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 520, in score_submission_on_workload
    timing, metrics = train_once(workload, global_batch_size,
  File "submission_runner.py", line 299, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/baselines/adamw/pytorch/submission.py", line 99, in update_params
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.51 GiB. GPU 5 has a total capacty of 15.78 GiB of which 765.44 MiB is free. Process 11976 has 15.03 GiB memory in use. Of the allocated memory 6.25 GiB is allocated by PyTorch, and 6.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Steps to Reproduce

PyTorch version: torch.dev08202023

torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py --framework=pytorch --workload=librispeech_conformer --submission_path=baselines/adamw/pytorch/submission.py --tuning_search_space=baselines/adamw/tuning_search_space.json --data_dir=/data/librispeech --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=tests/regression_tests/adamw --overwrite=True --save_checkpoints=False --max_global_steps=10 --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab --torch_compile=true
priyakasimbeg added the 🔥 PyTorch label on Aug 22, 2023
@msaroufim
Member

@janeyx99 have you seen this kind of non-deterministic optimizer OOM?

@janeyx99

janeyx99 commented Aug 23, 2023

this doesn’t look like it’s OOMing in the optimizer but rather in the backward, no?

the fact that it’s nondeterministic is definitely weird… meaning the extra memory can come from anywhere. otherwise I would guess that activation checkpointing would help here.
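
(For reference, not part of the original comment: a minimal sketch of what activation checkpointing with torch.utils.checkpoint could look like, assuming PyTorch 2.x; the wrapped block and tensor shapes are placeholders, not the repo's Conformer code.)

# Sketch only: recompute a block's activations during backward instead of
# storing them, trading extra compute for lower peak memory.
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(torch.nn.Module):
    def __init__(self, block: torch.nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # use_reentrant=False selects the non-reentrant autograd-based variant.
        return checkpoint(self.block, x, use_reentrant=False)


# Usage with a placeholder block standing in for e.g. a Conformer block:
block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
wrapped = CheckpointedBlock(block)
out = wrapped(torch.randn(8, 64, requires_grad=True))
out.sum().backward()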

@chandramouli-sastry
Contributor

chandramouli-sastry commented Aug 26, 2023

I ran the following script, which trains the Conformer model for 1000 steps and repeats this 10 times to see how often it errors out. All 10 times the model trained successfully without any error. I ran this inside a Docker container built from the Dockerfile on the main branch. @priyakasimbeg, could you please let me know if the following script errors out on the VM where you got the above error?

command = """torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py --framework=pytorch --workload=librispeech_conformer --submission_path=baselines/adamw/pytorch/submission.py --tuning_search_space=baselines/adamw/tuning_search_space.json --data_dir=/data/work_dir/data/ --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=tests/regression_tests/adamw --overwrite=True --save_checkpoints=False --max_global_steps=1000 --librispeech_tokenizer_vocab_path=/data/spm_model.vocab --torch_compile=true"""

import os
for i in range(10):
    code = os.system(command)
    print(code)```

@priyakasimbeg
Contributor Author

This is not reproducible anymore after Juhan's fixes in #502 for Criteo1tb memory issues. I believe clearing the cache after evals helped.
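
(For context, an editorial sketch rather than the actual #502 change: the cache clearing referred to above typically amounts to something like the following, called after each eval pass.)

# Sketch only, assuming the fix releases cached allocator blocks after
# evaluation so they don't linger into the next training step.
import torch


def clear_cuda_cache_after_eval() -> None:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # free cached blocks held by the caching allocator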

@priyakasimbeg
Contributor Author

priyakasimbeg commented Sep 18, 2023

Happened again on git commit 4c38ffb at step 1778 on kasimbeg-2

priyakasimbeg reopened this on Sep 18, 2023
@priyakasimbeg
Contributor Author

priyakasimbeg commented Sep 21, 2023

Here is the traceback: https://gist.github.com/priyakasimbeg/35a7e2562ed471aba6d8087da1e65fda.

Seems like it happens in the backward pass:

 File "/algorithmic-efficiency/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py", line 64, in update_params
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py", line 64, in update_params
        loss.backward()loss.backward()

        loss.backward()loss.backward()  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward


  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
    loss.backward()    
loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
        loss.backward()  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
loss.backward()

  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
                torch.autograd.backward(torch.autograd.backward(    torch.autograd.backward(torch.autograd.backward(

torch.autograd.backward(

    
torch.autograd.backward(  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
      File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward

  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
    torch.autograd.backward(  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
 File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
                Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cudatorch.cudatorch.cudatorch.cuda    ..    ..torch.cudaOutOfMemoryErrorVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passOutOfMemoryErrorVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passtorch.cudaOutOfMemoryError: OutOfMemoryError: ..
: CUDA out of memory. Tried to allocate 2.51 GiB (GPU 4; 15.78 GiB total capacity; 6.26 GiB already allocated; 541.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONFOutOfMemoryError
: OutOfMemoryError
CUDA out of memory. Tried to allocate 2.51 GiB (GPU 3; 15.78 GiB total capacity; 6.26 GiB already allocated; 541.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF: CUDA out of memory. Tried to allocate 2.51 GiB (GPU 1; 15.78 GiB total capacity; 6.26 GiB already allocated; 601.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF: 
CUDA out of memory. Tried to allocate 2.51 GiB (GPU 6; 15.78 GiB total capacity; 6.26 GiB already allocated; 541.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
CUDA out of memory. Tried to allocate 2.51 GiB (GPU 2; 15.78 GiB total capacity; 6.26 GiB already allocated; 541.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONFCUDA out of memory. Tried to allocate 2.51 GiB (GPU 0; 15.78 GiB total capacity; 6.26 GiB already allocated; 565.44 MiB free; 13.63 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

As discussed in the torch dev meeting, we will try turning off the fused/foreach paths in the NAdamW optimizer.
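
(For reference, a hedged sketch rather than the repo's actual submission code: what turning off the fused/foreach paths means when constructing a torch.optim optimizer, shown here with AdamW and placeholder hyperparameters.)

# Sketch only: disable the multi-tensor (foreach) and fused-kernel code paths,
# assuming torch>=2.0. `model` is a stand-in for the Conformer model.
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    foreach=False,  # per-parameter loop instead of the multi-tensor path
    fused=False,    # plain kernels instead of the fused CUDA implementation
)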

@janeyx99

Oh hey, I’m not sure how changing the optimizer would affect peak memory if the OOM is happening in the backward; could you give some rationale on why? Also, I’m not sure if you’re already aware, but a memory snapshot of one training loop would be helpful for debugging. Zach Devito’s blog describes how to capture one (https://zdevito.github.io/2022/12/09/memory-traces.html), and I’m happy to send you a code snippet if you’re interested in capturing a snapshot.

@priyakasimbeg
Contributor Author

Hey @janeyx99, thanks for taking a look at this!
You're right, I don't think this will address the memory used in the backward pass, but I thought (maybe wrongly) that it could reduce overall memory consumption. Turning off those optimizations didn't help, though.
And sure, a code snippet to capture a trace would be super helpful!

@janeyx99

janeyx99 commented Sep 22, 2023

oh ya! here’s a snippet where you surround the code you want to profile:

import torch
from pickle import dump

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history()

# train (placeholder for your training loop)
train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

Then you can drag the snapshot.pickle file onto https://zdevito.github.io/assets/trace.html to see the visualization of active memory usage over time.

@priyakasimbeg
Contributor Author

priyakasimbeg commented Sep 22, 2023

@janeyx99 I ran the profiler. The workload OOMs after ~4 steps.
I can't really see anything useful in the plot, though; it seems like the memory usage is very high from the start.
[screenshot: memory profile visualization; allocated memory is high and roughly flat from the start]

Also, when I hover over the blocks I see 'block was allocated before _record_history was enabled' even though I moved the torch.cuda.memory._record_memory_history(enabled=True) statement to the beginning of the main function in submission_runner.py.
Are there any special considerations for the DDP setting maybe?

@priyakasimbeg
Contributor Author

@msaroufim I tried capturing a trace using a single GPU and a single torchrun process, but I still see 'block was allocated before _record_history was enabled' for most of the blocks. Do you know how I could get more useful information out of the profiler?

@priyakasimbeg
Contributor Author

Tuning max_split_size_mb to 512 seems to have fixed this. Will send out a PR.
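
(For reference, a hedged sketch of how this allocator setting is typically applied; the actual PR may do it differently. Placing it at the top of submission_runner.py, before the first CUDA allocation, is an assumption.)

# Sketch only: set the caching-allocator option before torch touches CUDA.
# The value 512 mirrors the comment above.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch  # noqa: E402  # imported after the allocator config is set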

@janeyx99

ah, changing max_split_size_mb should help with fragmentation.

your memory profile staying flat seems to mean that you enabled the profiler while no allocations were being made, which is odd because the forward, backward, and optimizer steps should need intermediates at the very least. the distributed setup may have something to do with it; maybe this is profiling the wrong device…

@priyakasimbeg
Contributor Author

On second thought, it looks like reducing max_split_size_mb to 256 increases the submission time by 2x. @janeyx99 I also tried setting the device to 0 and running a single process instead of multiple DDP processes, but got the same graph.

@chandramouli-sastry
Contributor

Adjusting the scaled-dot-product-attention backend to use the math backend (slower than the other available alternatives, but it seems to match the default implementation in 1.13) removes the OOM errors. I tried several other things to fix this, but they weren't effective without this adjustment. I will adjust the traindiffs test accordingly and test this change before creating a PR.
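
(For reference, a hedged sketch of forcing the math SDPA backend with the torch.backends.cuda.sdp_kernel context manager from the PyTorch 2.0/2.1 API; how this is wired into the repo's Conformer attention layer is not shown here, and the tensor shapes are placeholders.)

# Sketch only: run scaled_dot_product_attention on the unfused "math" backend,
# which behaves like the pre-2.0 default implementation.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(8, 4, 128, 64, device="cuda")  # (batch, heads, seq, head_dim)

with torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_mem_efficient=False,
    enable_math=True,
):
    out = F.scaled_dot_product_attention(q, k, v)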

@priyakasimbeg
Contributor Author

Fixed in #549.
Thanks @chandramouli-sastry!!!

@lessw2020
Contributor

Hi @priyakasimbeg - is there any additional info related to this being re-opened?

@priyakasimbeg
Contributor Author

Upgrading the GPU driver to 535.104.05 seems to resolve the CUDA OOM, so we will upgrade the drivers on the competition hardware and mark this as resolved.

We also confirmed, per the recommendation from @lessw2020, that on PyTorch 2.1.0 setting the following option resolves the OOM:

import torch

if torch.cuda.is_available():
    torch.cuda.memory._set_allocator_settings('expandable_segments:True')

We won't have to use this flag after all, given the driver update, but we want to document it in case we run into issues in the future.
