diff --git a/download_models.py b/download_models.py index 7535335e..3fef7c37 100644 --- a/download_models.py +++ b/download_models.py @@ -1,8 +1,12 @@ """Contains the code to download all models specified in the config file. Executable as a standalone script.""" + from load_env_vars import load_env_vars load_env_vars() +import argparse +import time + from horde_model_reference.model_reference_manager import ModelReferenceManager from loguru import logger @@ -10,7 +14,7 @@ from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME -def download_all_models() -> None: +def download_all_models(purge_unused_loras: bool = False) -> None: """Download all models specified in the config file.""" horde_model_reference_manager = ModelReferenceManager( download_and_convert_legacy_dbs=True, @@ -45,6 +49,24 @@ def download_all_models() -> None: SharedModelManager.load_model_managers() + if purge_unused_loras: + logger.info("Purging unused LORAs...") + if SharedModelManager.manager.lora is None: + logger.error("Failed to load LORA model manager") + exit(1) + deleted_loras = SharedModelManager.manager.lora.delete_unused_loras(30) + logger.success(f"Purged {len(deleted_loras)} unused LORAs.") + + if bridge_data.allow_lora: + if SharedModelManager.manager.lora is None: + logger.error("Failed to load LORA model manager") + exit(1) + SharedModelManager.manager.lora.download_default_loras() + + while SharedModelManager.manager.lora.are_downloads_complete() is False: + logger.info("Waiting for LORA downloads to complete...") + time.sleep(8) + if bridge_data.allow_controlnet: if SharedModelManager.manager.controlnet is None: logger.error("Failed to load controlnet model manager") @@ -92,4 +114,13 @@ def download_all_models() -> None: if __name__ == "__main__": - download_all_models() + parser = argparse.ArgumentParser(description="Download all models specified in the config file.") + parser.add_argument( + "--purge-unused-loras", + action="store_true", + help="Purge unused LORAs from the cache", + ) + + args = parser.parse_args() + + download_all_models(purge_unused_loras=args.purge_unused_loras)