Added better GPU memory management
Dicklesworthstone committed May 28, 2024
1 parent 8e2ed18 commit 37e5c63
Showing 2 changed files with 62 additions and 36 deletions.
service_functions.py (46 changes: 24 additions & 22 deletions)
@@ -1,6 +1,6 @@
from logger_config import setup_logger
import shared_resources
from shared_resources import load_model, text_completion_model_cache, is_gpu_available
from shared_resources import load_model, text_completion_model_cache, is_gpu_available, evict_model_from_gpu
from database_functions import AsyncSessionLocal, execute_with_retry
from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text, suppress_stdout_stderr, image_to_base64_data_uri, process_image, find_clip_model_path
from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, AudioTranscript
@@ -483,38 +483,40 @@ def load_text_completion_model(llm_model_name: str, raise_http_exception: bool =
matching_files.sort(key=os.path.getmtime, reverse=True)
model_file_path = matching_files[0]
is_llava_multimodal_model = 'llava' in llm_model_name and 'mmproj' not in llm_model_name
chat_handler = None # Determine the appropriate chat handler based on the model name
chat_handler = None
if 'llava' in llm_model_name:
clip_model_path = find_clip_model_path(llm_model_name)
if clip_model_path is None:
raise FileNotFoundError
chat_handler = Llava16ChatHandler(clip_model_path=clip_model_path)
with suppress_stdout_stderr():
gpu_info = is_gpu_available()
if gpu_info:
num_gpus = gpu_info['num_gpus']
if num_gpus > 1:
llama_split_mode = 2 # 2, // split rows across GPUs | 1, // split layers and KV across GPUs
else:
llama_split_mode = 0
else:
num_gpus = 0
model_instance = Llama(
model_path=model_file_path,
embedding=True if is_llava_multimodal_model else False,
n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
flash_attn=USE_FLASH_ATTENTION,
verbose=USE_VERBOSE,
llama_split_mode=llama_split_mode,
n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
clip_model_path=clip_model_path if is_llava_multimodal_model else None,
chat_handler=chat_handler
)
llama_split_mode = 2 if gpu_info and gpu_info['num_gpus'] > 1 else 0
while True:
try:
model_instance = Llama(
model_path=model_file_path,
embedding=True if is_llava_multimodal_model else False,
n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
flash_attn=USE_FLASH_ATTENTION,
verbose=USE_VERBOSE,
llama_split_mode=llama_split_mode,
n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
clip_model_path=clip_model_path if is_llava_multimodal_model else None,
chat_handler=chat_handler
)
break
except ValueError as e:
if "cudaMalloc failed: out of memory" in str(e):
evict_model_from_gpu()
else:
raise
text_completion_model_cache[llm_model_name] = model_instance
shared_resources.loaded_models[llm_model_name] = model_instance
return model_instance
except TypeError as e:
logger.error(f"TypeError occurred while loading the model: {e}")
logger.error(traceback.format_exc())
logger.error(traceback.format_exc())
raise
except Exception as e:
logger.error(f"Exception occurred while loading the model: {e}")
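The hunk above replaces the one-shot Llama(...) call with a retry loop: when construction fails with a CUDA out-of-memory error, the oldest cached model is evicted via evict_model_from_gpu() and the load is attempted again. Below is a minimal standalone sketch of that pattern; load_with_gpu_retry is a hypothetical helper name, not part of the commit, and the generic loader callable stands in for the Llama(...) constructor used above.

from typing import Callable, TypeVar

T = TypeVar("T")

def load_with_gpu_retry(loader: Callable[[], T], evict: Callable[[], None]) -> T:
    """Call loader until it succeeds, evicting one cached model after each
    CUDA out-of-memory failure; any other error is re-raised unchanged."""
    while True:
        try:
            return loader()
        except ValueError as e:
            # Per the diff above, llama-cpp-python reports CUDA allocation
            # failures as a ValueError containing this message fragment.
            if "cudaMalloc failed: out of memory" in str(e):
                evict()  # free the oldest loaded model, then try again
            else:
                raise

# Hypothetical usage, mirroring the call in load_text_completion_model():
# model_instance = load_with_gpu_retry(
#     loader=lambda: Llama(model_path=model_file_path, n_gpu_layers=-1),
#     evict=evict_model_from_gpu,
# )

One caveat: like the committed loop, this sketch retries unconditionally, so if the cache is already empty and allocation still fails it will spin; a defensive variant might raise once there is nothing left to evict.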
shared_resources.py (52 changes: 38 additions & 14 deletions)
@@ -1,4 +1,4 @@
from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes
from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes, suppress_stdout_stderr
from database_functions import DatabaseWriter, initialize_db, AsyncSessionLocal, delete_expired_rows
from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks
from logger_config import setup_logger
@@ -17,11 +17,13 @@
from decouple import config
from fastapi import HTTPException
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from collections import OrderedDict

logger = setup_logger()

embedding_model_cache = {} # Model cache to store loaded models
text_completion_model_cache = {} # Model cache to store loaded text completion models
embedding_model_cache = OrderedDict() # Model cache to store loaded models with LRU eviction
text_completion_model_cache = OrderedDict() # Model cache to store loaded text completion models with LRU eviction
loaded_models = OrderedDict() # Track loaded models to manage GPU memory

SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
@@ -111,9 +113,8 @@ async def initialize_globals():
redis = None
lock_manager = None


def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
download_status = []
download_status = []
json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
if not os.path.exists(json_path):
initial_model_urls = [
@@ -149,7 +150,7 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
status = {"url": url, "status": "success", "message": "File already exists."}
filename = os.path.join(models_dir, model_name_with_extension)
try:
with lock.acquire(timeout=1200): # Wait up to 20 minutes for the file to be downloaded before returning failure
with lock.acquire(timeout=1200): # Wait up to 20 minutes for the file to be downloaded before returning failure
if not os.path.exists(filename):
logger.info(f"Downloading model {model_name_with_extension} from {url}...")
urllib.request.urlretrieve(url, filename)
@@ -172,6 +173,12 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
logger.info("Model downloads completed.")
return model_names, download_status

def evict_model_from_gpu():
if loaded_models:
evicted_model_name, evicted_model_instance = loaded_models.popitem(last=False)
del evicted_model_instance
logger.info(f"Evicted model {evicted_model_name} from GPU memory")

def load_model(llm_model_name: str, raise_http_exception: bool = True):
global USE_VERBOSE
model_instance = None
@@ -186,16 +193,33 @@ def load_model(llm_model_name: str, raise_http_exception: bool = True):
matching_files.sort(key=os.path.getmtime, reverse=True)
model_file_path = matching_files[0]
gpu_info = is_gpu_available()
if 'llava' in llm_model_name:
is_llava_multimodal_model = 1
else:
is_llava_multimodal_model = 0
if not is_llava_multimodal_model:
if gpu_info['gpu_found']:
model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE, n_gpu_layers=-1) # Load the model with GPU acceleration
is_llava_multimodal_model = 'llava' in llm_model_name
with suppress_stdout_stderr():
if is_llava_multimodal_model:
pass
else:
model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE) # Load the model without GPU acceleration
while True:
try:
if gpu_info['gpu_found']:
model_instance = llama_cpp.Llama(
model_path=model_file_path, embedding=True,
n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
verbose=USE_VERBOSE, n_gpu_layers=-1
) # Load the model with GPU acceleration
else:
model_instance = llama_cpp.Llama(
model_path=model_file_path, embedding=True,
n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
verbose=USE_VERBOSE
) # Load the model without GPU acceleration
break
except ValueError as e:
if "cudaMalloc failed: out of memory" in str(e):
evict_model_from_gpu()
else:
raise
embedding_model_cache[llm_model_name] = model_instance
loaded_models[llm_model_name] = model_instance
return model_instance
except TypeError as e:
logger.error(f"TypeError occurred while loading the model: {e}")
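The evict_model_from_gpu() helper added above leans on the insertion order that OrderedDict preserves: each model is appended to loaded_models as it is loaded, and popitem(last=False) removes the entry that has been resident longest. A minimal standalone sketch of that cache shape follows; register_model and evict_oldest_model are illustrative names, not identifiers from the repository.

from collections import OrderedDict
from typing import Any, Optional

loaded_models: "OrderedDict[str, Any]" = OrderedDict()

def register_model(name: str, instance: Any) -> None:
    loaded_models[name] = instance  # newest entries land at the end

def evict_oldest_model() -> Optional[str]:
    """Drop the longest-resident model and return its name, or None if empty."""
    if not loaded_models:
        return None
    name, instance = loaded_models.popitem(last=False)  # oldest entry comes off first
    # Dropping this reference lets the runtime free the model's memory once no
    # other references to the instance remain.
    del instance
    return name

Because the hunks shown only append to loaded_models at load time, popitem(last=False) gives load-order (FIFO) eviction; a strict LRU policy would additionally call loaded_models.move_to_end(llm_model_name) whenever a cached model is reused.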
