diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 1cd437e611..d3eb81040d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -26,7 +26,9 @@ } + class ModelManager(object): + tqdm_progress = None """Manage TTS models defined in .models.json. It provides an interface to list and download models defines in '.model.json' @@ -525,12 +527,12 @@ def _download_zip_file(file_url, output_folder, progress_bar): total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) with open(temp_zip_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) with zipfile.ZipFile(temp_zip_name) as z: z.extractall(output_folder) @@ -560,12 +562,12 @@ def _download_tar_file(file_url, output_folder, progress_bar): total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) with open(temp_tar_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) with tarfile.open(temp_tar_name) as t: t.extractall(output_folder) @@ -596,10 +598,10 @@ def _download_model_files(file_urls, output_folder, progress_bar): block_size = 1024 # 1 Kibibyte with open(temp_zip_name, "wb") as file: if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) @staticmethod