[BUG] Recurrently running deepspeed ends up overflowing GPU memory #4222

Closed
shankarp8 opened this issue Aug 25, 2023 · 5 comments
Labels: bug, training

@shankarp8
I am conducting research on model editing. I apply different editing methods to edit a transformer model once on one sample from my dataset, then revert it to the original weights (using a deepcopy of the original model) and edit it again on a different sample. Each time, I train the model using DeepSpeed's ZeRO optimizer at stage 2. This is a different use case from the majority of DeepSpeed workloads, which train the model once and run inference once within a given process.

The issue appears to be that DeepSpeed leaves some residual memory on the GPUs, so each time I edit again there is more and more memory in use until the GPU runs out of memory. I have tried deleting the model_engine each time, clearing torch's CUDA cache, and running Python's garbage collector, but none of these work.
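For clarity, the cleanup attempted at the end of each loop looks roughly like this (a sketch; model_engine, optimizer, and model are the variables from the repro script below, and the gc import is not shown there):

import gc

# Attempted cleanup at the end of each editing iteration -- none of these
# releases the residual GPU memory in practice.
del model_engine, optimizer, model   # drop references to the DeepSpeed engine and the edited copy
gc.collect()                         # force Python garbage collection
torch.cuda.empty_cache()             # return PyTorch's cached CUDA blocks to the driver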

To Reproduce
Any simple training loop using DeepSpeed should be sufficient to reproduce this error; I used Llama-7B:

from transformers import LlamaForCausalLM, LlamaTokenizer
import copy
import deepspeed
import torch
import subprocess

def get_gpu_memory_usage(): #to illustrate the issue
    result = subprocess.run(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"], stdout=subprocess.PIPE)
    memory_used = result.stdout.decode("utf-8").strip().split("\n")
    return [int(mem) for mem in memory_used]

llama_path = "..."  # path where Llama is stored

deepspeed_args = {
    "bf16": {
        "enabled": True
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "torch_adam": True,
            "lr": 0.00001,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": True
    },

    "gradient_accumulation_steps": 2,
    "gradient_clipping": 1,
    "train_batch_size": 4,
    "train_micro_batch_size_per_gpu": 2
}

device = torch.device("cuda:0")        # device of your choice for the original model
other_device = torch.device("cuda:1")  # a different device, so the edited copy does not share a GPU with the original

model_original = LlamaForCausalLM.from_pretrained(llama_path).to(device)
tokenizer = LlamaTokenizer.from_pretrained(llama_path)


for i in range(10):
    print('GPU MEMORY USAGE AT STEP {}'.format(i), get_gpu_memory_usage())
    model = copy.deepcopy(model_original).to(other_device)  # working copy on a different device from the original, for memory reasons
    sentence = 'Arbitrary sentence to fine-tune on'  # the same sentence each time
    inputs = tokenizer(sentence, return_tensors="pt").to(model.device)
    model.train()
    model.resize_token_embeddings(len(tokenizer))
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config_params=deepspeed_args,
    )
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
    loss = outputs.loss
    model_engine.backward(loss)
    model_engine.step()

Expected behavior
After each loop iteration, some memory is left on the GPU, eventually causing it to run out of memory. The 'GPU MEMORY USAGE AT STEP i' print statement should make this clear.

System info (please complete the following information):

  • OS: Ubuntu 18.04.6
  • GPU count and types: 1 machine with 8 NVIDIA A40s
  • Python version: 3.9.12

Launcher context
I run with the Python launcher (python program.py). The deepspeed launcher appears to place memory on every GPU automatically and does not let me withhold a few visible GPUs for other parts of my script, which I need.
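For context, with the Python launcher the visible GPUs can be restricted roughly like this (a sketch; the device indices are arbitrary examples):

import os

# Must be set before torch/deepspeed touch CUDA; the GPUs left out of this
# list stay free for other parts of the script.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"

import torch
import deepspeed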

shankarp8 added the bug and training labels on Aug 25, 2023
shankarp8 changed the title from "[BUG]" to "[BUG] Recurrently running deepspeed ends up overflowing GPU memory" on Aug 26, 2023
jomayeri self-assigned this on Sep 7, 2023
@jomayeri (Contributor) commented Sep 7, 2023

@shankarp8 If you run this loop without DeepSpeed, do you not see the issue?

@shankarp8 (Author)

Well, I can't fine-tune Llama-7B on my GPUs (A40s with 48GB of GPU memory) without DeepSpeed, so I replaced it with GPT2-XL for now. To be precise, without DeepSpeed the memory at the start of each loop (output by the print statement in my code above) is 6GB -> 12GB -> 12GB -> 12GB -> 12GB ... -> 12GB.
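For reference, the no-DeepSpeed control looks roughly like this (a sketch assuming GPT2-XL with plain torch.optim.AdamW; get_gpu_memory_usage is the helper from the repro script above, and the device indices are arbitrary):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import copy
import torch

model_original = GPT2LMHeadModel.from_pretrained("gpt2-xl").to("cuda:0")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")

for i in range(10):
    print('GPU MEMORY USAGE AT STEP {}'.format(i), get_gpu_memory_usage())
    model = copy.deepcopy(model_original).to("cuda:1")   # working copy on a different device
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    inputs = tokenizer('Arbitrary sentence to fine-tune on', return_tensors="pt").to(model.device)
    model.train()
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this loop the reported memory settles at 12GB after the first iteration instead of growing.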

By contrast, the GPU memory across the loops when using DeepSpeed with GPT2-XL is 6GB -> 12GB -> 15GB -> 18GB -> 21GB -> 24GB -> 27GB -> 30GB -> 33GB. Noticeably, when using DeepSpeed an extra ~3GB (roughly half the model size; for Llama-7B it would be about 14GB) is left on the GPUs each time.

@jomayeri (Contributor)

@shankarp8 Can you try with the branch in #4383?

@shankarp8 (Author)

Hi, I tried using that branch with model_engine.destroy() at the end of every loop (let me know if that is what I was supposed to do), and unfortunately it still hits the same issue.
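Concretely, the end of each loop iteration with that branch installed looks roughly like this (a sketch, continuing the cleanup shown in the issue description):

model_engine.destroy()      # tear down the engine and its ZeRO optimizer state
del model_engine, optimizer, model
gc.collect()
torch.cuda.empty_cache()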

@jomayeri (Contributor)

After further investigation, it looks like we won't be able to clear everything off the GPU by destroying the ZeRO optimizers, but that is the best we can do at the moment.
