diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index 592e1bf7..77f95288 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -1,6 +1,6 @@ import os import torch -from typing import List, Union, BinaryIO +from typing import List, Union, BinaryIO, Optional import numpy as np import time import logging @@ -24,7 +24,7 @@ def run(self, audio: Union[str, BinaryIO, np.ndarray], transcribed_result: List[dict], use_auth_token: str, - device: str + device: Optional[str] = None ): """ Diarize transcribed result as a post-processing @@ -38,7 +38,7 @@ def run(self, use_auth_token: str Huggingface token with READ permission. This is only needed the first time you download the model. You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model. - device: str + device: Optional[str] Device for diarization. Returns @@ -50,8 +50,10 @@ def run(self, """ start_time = time.time() - if (device != self.device - or self.pipe is None): + if device is None: + device = self.device + + if device != self.device or self.pipe is None: self.update_pipe( device=device, use_auth_token=use_auth_token @@ -89,6 +91,7 @@ def update_pipe(self, device: str Device for diarization. """ + self.device = device os.makedirs(self.model_dir, exist_ok=True) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index d312c753..c606398f 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -130,7 +130,6 @@ def run(self, audio=audio, use_auth_token=params.hf_token, transcribed_result=result, - device=self.device ) elapsed_time += elapsed_time_diarization return result, elapsed_time