Skip to content

Commit

Permalink
Merge pull request #422 from jhj0517/feature/add-offload
Browse files Browse the repository at this point in the history
Feature/add offload
  • Loading branch information
jhj0517 authored Dec 13, 2024
2 parents 7c91344 + be45659 commit 95d4f9a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions modules/diarize/diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import time
import logging
import gc

from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
Expand Down Expand Up @@ -121,6 +122,16 @@ def update_pipe(self,
)
logger.disabled = False

def offload(self):
"""Offload the model and free up the memory"""
if self.pipe is not None:
del self.pipe
self.pipe = None
if self.device == "cuda":
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
gc.collect()

@staticmethod
def get_device():
if torch.cuda.is_available():
Expand Down
10 changes: 10 additions & 0 deletions modules/translation/translation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import gradio as gr
from abc import ABC, abstractmethod
import gc
from typing import List
from datetime import datetime

Expand Down Expand Up @@ -128,6 +129,15 @@ def translate_file(self,
finally:
self.release_cuda_memory()

def offload(self):
"""Offload the model and free up the memory"""
if self.model is not None:
del self.model
self.model = None
if self.device == "cuda":
self.release_cuda_memory()
gc.collect()

@staticmethod
def get_device():
if torch.cuda.is_available():
Expand Down
10 changes: 10 additions & 0 deletions modules/whisper/base_transcription_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
import gc

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
Expand Down Expand Up @@ -414,6 +415,15 @@ def get_available_compute_type(self):
else:
return list(ctranslate2.get_supported_compute_types("cpu"))

def offload(self):
"""Offload the model and free up the memory"""
if self.model is not None:
del self.model
self.model = None
if self.device == "cuda":
self.release_cuda_memory()
gc.collect()

@staticmethod
def format_time(elapsed_time: float) -> str:
"""
Expand Down

0 comments on commit 95d4f9a

Please sign in to comment.