Skip to content

Commit

Permalink
Adding option to specify custom model root (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
raivisdejus authored Aug 27, 2024
1 parent f6fc65e commit 4d9547d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
17 changes: 6 additions & 11 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)
os.makedirs(model_root_dir, exist_ok=True)

logging.debug("Model root directory: %s", model_root_dir)
Expand Down Expand Up @@ -270,33 +271,27 @@ def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size
self.model_root = model_root
self.progress = progress
self.total_file_size = total_file_size
self.tmp_download_root = None
self.incomplete_download_root = None
self.stop_event = threading.Event()
self.monitor_thread = None
self.set_download_roots()

def set_download_roots(self):
normalized_model_root = os.path.normpath(self.model_root)
normalized_hub_path = os.path.normpath("/models/")
index = normalized_model_root.find(normalized_hub_path)
if index > 0:
self.tmp_download_root = normalized_model_root[:index + len(normalized_hub_path)]

two_dirs_up = os.path.normpath(os.path.join(normalized_model_root, "..", ".."))
self.incomplete_download_root = os.path.normpath(os.path.join(two_dirs_up, "blobs"))

def clean_tmp_files(self):
for filename in os.listdir(self.tmp_download_root):
for filename in os.listdir(model_root_dir):
if filename.startswith("tmp"):
os.remove(os.path.join(self.tmp_download_root, filename))
os.remove(os.path.join(model_root_dir, filename))

def monitor_file_size(self):
while not self.stop_event.is_set():
if self.tmp_download_root is not None:
for filename in os.listdir(self.tmp_download_root):
if model_root_dir is not None:
for filename in os.listdir(model_root_dir):
if filename.startswith("tmp"):
file_size = os.path.getsize(os.path.join(self.tmp_download_root, filename))
file_size = os.path.getsize(os.path.join(model_root_dir, filename))
self.progress.emit((file_size, self.total_file_size))

for filename in os.listdir(self.incomplete_download_root):
Expand Down
1 change: 1 addition & 0 deletions buzz/transcriber/recording_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def start(self):
elif self.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)

device = "auto"
if platform.system() == "Windows":
Expand Down
1 change: 1 addition & 0 deletions buzz/transcriber/whisper_file_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]

model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
model_root_dir = os.getenv("BUZZ_MODEL_ROOT", model_root_dir)

device = "auto"
if platform.system() == "Windows":
Expand Down
2 changes: 2 additions & 0 deletions docs/docs/preferences.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,5 @@ combined to produce the final answer.
**BUZZ_TRANSLATION_API_BASE_URl** - Base URL of OpenAI compatible API to use for translation. Available from `v1.0.2`.

**BUZZ_TRANSLATION_API_KEY** - Api key of OpenAI compatible API to use for translation. Available from `v1.0.2`.

**BUZZ_MODEL_ROOT** - Root directory to store model files. Defaults to [user_cache_dir](https://pypi.org/project/platformdirs/). Available from `v1.0.2`.

0 comments on commit 4d9547d

Please sign in to comment.