
OOM on training Mistral hypernet #8

Open
kdcyberdude opened this issue Jun 10, 2024 · 2 comments

@kdcyberdude

Hi @bminixhofer, I am getting an OOM with the following logs when training a Mistral multilingual hypernet. I have also tried this on two A100s (80GB). Not sure what is wrong!

I have created a branch containing a script to reproduce this. You can run the ./install script on any vast.ai instance - main...kdcyberdude:zett:main

2024-06-10 00:21:14.077507: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.00GiB (rounded to 1073741824)requested by op                                                                         
2024-06-10 00:21:14.078739: W external/tsl/tsl/framework/bfc_allocator.cc:497] ****************************************************************************************************                                                                              
2024-06-10 00:21:14.078795: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.                                                             
BufferAssignment OOM Debugging.                                                                                                                                                                                                                                  
BufferAssignment stats:                                                                                                                                                                                                                                          
             parameter allocation:    1.00GiB                                                                                                                                                                                                                    
              constant allocation:         0B                                                                                                                                                                                                                    
        maybe_live_out allocation:    1.00GiB                                                                                                                                                                                                                    
     preallocated temp allocation:         0B                                                                                                                                                                                                                    
                 total allocation:    2.00GiB                                                                                                                                                                                                                    
              total fragmentation:         0B (0.00%)                                                                                                                                                                                                            
Peak buffers:                                                                                                                                                                                                                                                    
        Buffer 1:                                                                                                                                                                                                                                                
                Size: 1.00GiB                                                                                                                                                                                                                                    
                Entry Parameter Subshape: pred[1,32768,32768]                                                                                                                                                                                                    
                ==========================                                                                                                                                                                                                                       
                                                                                                                                                                                                                                                                 
        Buffer 2:                                                                                                                                                                                                                                                
                Size: 1.00GiB                                                                                                                                                                                                                                    
                XLA Label: fusion                                                                                                                                                                                                                                
                Shape: pred[1,1,32768,32768]                                                                                                                                                                                                                     
                ==========================                                                                                                                                                                                                                       
                                                                                                                                                                                                                                                                 
                                                                                                                                                                                                                                                                 
  0%|                                                                                                                                                                                                                                 | 0/100000 [00:37<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                                                                                               
  File "/workspace/zett/train.py", line 1625, in <module>                                                                                                                                                                                                        
    main()                                                                                                                                                                                                                                                       
  File "/workspace/zett/train.py", line 1526, in main                                                                                                                                                                                                            
    state, train_metric = current_step_fn(state, batch)                                                                                                                                                                                                          
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                          
  File "/workspace/zett/train.py", line 1203, in train_step                                                                                                                                                                                                      
    (loss, (lexical_loss, mean_lexical_overlap)), grad = grad_fn(state.params)                                                                                                                                                                                   
                                                         ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                   
  File "/workspace/zett/train.py", line 1116, in compute_loss                                                                                                                                                                                                    
    ) = compute_embeddings_and_logits(                                                                                                                                                                                                                           
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                           
  File "/workspace/zett/train.py", line 1092, in compute_embeddings_and_logits                                                                                                                                                                                   
    logits = model_fn(                                                                                                                                                                                                                                           
             ^^^^^^^^^                                                                                                                                                                                                                                           
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 502, in __call__                                                                                                                           
    outputs = self.module.apply(                                                                                                                                                                                                                                 
              ^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                                 
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 677, in __call__                                                                                                                           
    outputs = self.model(                                                                                                                                                                                                                                        
              ^^^^^^^^^^^                                                                                                                                                                                                                                        
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 605, in __call__                                                                                                                           
    outputs = self.layers(                                                                                                                                                                                                                                       
              ^^^^^^^^^^^^                                                                                                                                                                                                                                       
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 556, in __call__                                                                                                                           
    layer_outputs = block(                                                                                                                                                                                                                                       
                    ^^^^^^                                                                                                                                                                                                                                       
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 374, in __call__                                                                                                                           
    outputs = self.self_attn(                                                                                                                                                                                                                                    
              ^^^^^^^^^^^^^^^                                                                                                                                                                                                                                    
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 241, in setup                                                                                                                              
    casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")                                                                                                                                                    
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                    
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 810, in make_causal_mask                                                                                                                                                
    return make_attention_mask(                                                                                                                                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                                  
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 786, in make_attention_mask                                           
    mask = jnp.expand_dims(mask, axis=-3)                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                      
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 912, in expand_dims                                               
    return lax.expand_dims(a, axis)                                            
           ^^^^^^^^^^^^^^^^^^^^^^^^                                            
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.                                                                       
BufferAssignment OOM Debugging.                                                
BufferAssignment stats:                                                        
             parameter allocation:    1.00GiB                                  
              constant allocation:         0B                                  
        maybe_live_out allocation:    1.00GiB                                  
     preallocated temp allocation:         0B                                  
                 total allocation:    2.00GiB                                  
              total fragmentation:         0B (0.00%)                          
Peak buffers:                                                                  
        Buffer 1:                                                                                
                Size: 1.00GiB                                                                    
                Entry Parameter Subshape: pred[1,32768,32768]                                    
                ==========================                                                       

        Buffer 2:                                                                                
                Size: 1.00GiB                                                                    
                XLA Label: fusion                                                                
                Shape: pred[1,1,32768,32768]                                                                           
                ==========================   
@bminixhofer (Owner)

Hi, it looks like it is trying to create a very large causal mask due to the high max_position_embeddings (a boolean mask of shape [1, 32768, 32768] is exactly the 1.00GiB buffer in your log). You can try manually lowering max_position_embeddings to the block_size, which should make it a lot easier on memory (and should be safe to do).
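For illustration, something along these lines (a rough sketch, not the exact code path in train.py; the checkpoint name, dtype, and block size below are placeholders):

```python
# Sketch: cap max_position_embeddings at the training block size before the
# Flax model is built, so the attention module's make_causal_mask allocates a
# block_size x block_size boolean mask instead of a 32768 x 32768 one
# (the 1.00GiB "pred[1,32768,32768]" buffer in the log above).
import jax.numpy as jnp
from transformers import MistralConfig, FlaxMistralForCausalLM

block_size = 128  # placeholder: whatever block size training actually uses

config = MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.max_position_embeddings = block_size  # Mistral's default is 32768

# _do_init=False defers parameter allocation until training setup
model = FlaxMistralForCausalLM(config, dtype=jnp.bfloat16, _do_init=False)
```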

@kdcyberdude (Author)

Hi @bminixhofer, do I need to set max_position_embeddings to 128 when initializing the roberta-base model in zett/model/__init__.py?

I tried without using a pretrained hypernet model as well; it still gives an OOM.

Also, what is the VRAM requirement for training this on a GPU?

Logs

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1718174206.615502 474103 hlo_rematerialization.cc:2946] Can't reduce memory use below -18.34GiB (-19688159409 bytes) by rematerialization; only reduced to 21.55GiB (23140745816 bytes), down from 21.55GiB (23140745816 bytes) originally
2024-06-12 12:06:47.545238: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   16.50GiB
              constant allocation:        22B
        maybe_live_out allocation:   21.55GiB
     preallocated temp allocation:    13.4KiB
                 total allocation:   38.05GiB
              total fragmentation:    13.4KiB (0.00%)
Peak buffers:
    Buffer 1:
            Size: 1000.00MiB
            Entry Parameter Subshape: f32[32000,8192]
            ==========================

    Buffer 2:
            Size: 1000.00MiB
            XLA Label: fusion
            Shape: f32[32000,8192]
            ==========================

    Buffer 3:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 4:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 5:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 6:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 7:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 8:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 9:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 10:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 11:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 12:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 13:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 14:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 15:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

Traceback (most recent call last):
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/mnt/pi/proj/jun/zett/train.py", line 1629, in <module>
    main()
  File "/mnt/pi/proj/jun/zett/train.py", line 848, in main
    state = jax.jit(
            ^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   16.50GiB
              constant allocation:        22B
        maybe_live_out allocation:   21.55GiB
     preallocated temp allocation:    13.4KiB
                 total allocation:   38.05GiB
              total fragmentation:    13.4KiB (0.00%)
Peak buffers:
    Buffer 1:
            Size: 1000.00MiB
            Entry Parameter Subshape: f32[32000,8192]
            ==========================

    Buffer 2:
            Size: 1000.00MiB
            XLA Label: fusion
            Shape: f32[32000,8192]
            ==========================

    Buffer 3:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 4:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 5:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 6:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 7:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 8:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 9:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 10:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 11:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 12:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 13:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 14:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 15:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

PS: I am new to JAX.
