From c3a746fee5ca28e534e5f20d1491fcd3e528ed97 Mon Sep 17 00:00:00 2001 From: Tessa Painter Date: Thu, 23 Nov 2023 17:45:26 -0600 Subject: [PATCH] Made the tqdm `progress_bar` objects of static download methods a static class variable --- TTS/utils/manage.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index eef987efd4..d3eb81040d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -26,7 +26,9 @@ } + class ModelManager(object): + tqdm_progress = None """Manage TTS models defined in .models.json. It provides an interface to list and download models defines in '.model.json' @@ -109,7 +111,6 @@ def _list_models(self, model_type, model_count=0): def _list_for_model_type(self, model_type): models_name_list = [] model_count = 1 - model_type = "tts_models" models_name_list.extend(self._list_models(model_type, model_count)) return models_name_list @@ -298,22 +299,22 @@ def _set_model_item(self, model_name): model_item = self.set_model_url(model_item) return model_item, model_full_name, model, md5hash - def ask_tos(self, model_full_path): + @staticmethod + def ask_tos(model_full_path): """Ask the user to agree to the terms of service""" tos_path = os.path.join(model_full_path, "tos_agreed.txt") - if not os.path.exists(tos_path): - print(" > You must agree to the terms of service to use this model.") - print(" | > Please see the terms of service at https://coqui.ai/cpml.txt") - print(' | > "I have read, understood and agreed the Terms and Conditions." - [y/n]') - answer = input(" | | > ") - if answer.lower() == "y": - with open(tos_path, "w") as f: - f.write("I have read, understood ad agree the Terms and Conditions.") - return True - else: - return False + print(" > You must agree to the terms of service to use this model.") + print(" | > Please see the terms of service at https://coqui.ai/cpml.txt") + print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]') + answer = input(" | | > ") + if answer.lower() == "y": + with open(tos_path, "w", encoding="utf-8") as f: + f.write("I have read, understood and agreed to the Terms and Conditions.") + return True + return False - def tos_agreed(self, model_item, model_full_path): + @staticmethod + def tos_agreed(model_item, model_full_path): """Check if the user has agreed to the terms of service""" if "tos_required" in model_item and model_item["tos_required"]: tos_path = os.path.join(model_full_path, "tos_agreed.txt") @@ -392,7 +393,7 @@ def download_model(self, model_name): self.create_dir_and_download_model(model_name, model_item, output_path) # if the configs are different, redownload it # ToDo: we need a better way to handle it - if "xtts_v1" in model_name: + if "xtts" in model_name: try: self.check_if_configs_are_equal(model_name, model_item, output_path) except: @@ -406,7 +407,7 @@ def download_model(self, model_name): output_model_path = output_path output_config_path = None if ( - model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name ): # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json @@ -526,12 +527,12 @@ def _download_zip_file(file_url, output_folder, progress_bar): total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) with open(temp_zip_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) with zipfile.ZipFile(temp_zip_name) as z: z.extractall(output_folder) @@ -561,12 +562,12 @@ def _download_tar_file(file_url, output_folder, progress_bar): total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) with open(temp_tar_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) with tarfile.open(temp_tar_name) as t: t.extractall(output_folder) @@ -597,10 +598,10 @@ def _download_model_files(file_urls, output_folder, progress_bar): block_size = 1024 # 1 Kibibyte with open(temp_zip_name, "wb") as file: if progress_bar: - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) for data in r.iter_content(block_size): if progress_bar: - progress_bar.update(len(data)) + ModelManager.tqdm_progress.update(len(data)) file.write(data) @staticmethod