Skip to content

Commit

Permalink
Unload vllm model
Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Oct 1, 2024
1 parent 58b6d98 commit 2247ce4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
13 changes: 13 additions & 0 deletions rl/llm/engines/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import gc
import math
import os
import socket
Expand Down Expand Up @@ -377,7 +378,19 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
    """Release GPU memory held by the vLLM engine on context exit.

    Tears down vLLM's model-parallel state, drops the engine reference,
    runs a GC pass, frees CUDA caches, and shuts down the distributed
    process group. Returns None implicitly, so any exception raised in
    the with-body is propagated, not suppressed.
    """
    import torch
    import torch.distributed
    from vllm.model_executor.parallel_utils.parallel_state import (
        destroy_model_parallel,
    )

    LOGGER.info("Unloading VLLM model from GPU memory...")
    destroy_model_parallel()
    del self.vllm
    gc.collect()
    # empty_cache() raises on a CPU-only torch build; only call it when
    # CUDA is actually available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # destroy_process_group() raises a RuntimeError if no default process
    # group was ever initialized (e.g. single-GPU, non-distributed runs) —
    # guard so teardown never crashes the context exit.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    LOGGER.info("VLLM model unloaded.")

def generate(self, prompt: InferenceInput) -> InferenceOutput:
    """Run inference on a single prompt.

    Thin convenience wrapper: wraps the prompt into a one-element batch,
    delegates to batch_generate, and returns the first (only) result.
    """
    results = self.batch_generate([prompt])
    return results[0]
Expand Down
6 changes: 6 additions & 0 deletions rl/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def write_csv(


def download(url: str, dest: str | Path) -> None:
"""Download a file from a URL to a destination path, with a progress bar.
Args:
url: The URL to download from.
dest: The destination path.
"""
dest = Path(dest)
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
Expand Down

0 comments on commit 2247ce4

Please sign in to comment.