def hf_list_models(
    self,
    query: str,
    token: str,
) -> list[Model]:
    """Search the Hugging Face Hub for diffusers text-to-image models.

    Args:
        query: Free-text search string forwarded to the Hub API.
        token: Hub auth token (passed as ``use_auth_token``).

    Returns:
        A list of ``Model`` entries with ``model_type`` left as ``UNKNOWN``
        (the type can only be inferred once weights are downloaded).
    """
    from huggingface_hub import HfApi, ModelFilter

    # Reuse one API client across calls instead of constructing a new one.
    if hasattr(self, "huggingface_hub_api"):
        api: HfApi = self.huggingface_hub_api
    else:
        api = HfApi()
        setattr(self, "huggingface_hub_api", api)

    # Renamed from `filter`, which shadowed the builtin.
    model_filter = ModelFilter(tags="diffusers", task="text-to-image")
    models = api.list_models(filter=model_filter, search=query, use_auth_token=token)
    return [
        Model(
            m.modelId,
            m.author or "",
            m.tags,
            getattr(m, "likes", 0),
            getattr(m, "downloads", -1),
            ModelType.UNKNOWN,
        )
        for m in models
        # Guard against sparse API payloads: skip entries missing an id/tags.
        if m.modelId is not None and m.tags is not None and "diffusers" in m.tags
    ]


def hf_list_installed_models(self) -> list[Model]:
    """List models already present in the local diffusers cache.

    Walks every directory in ``DIFFUSERS_CACHE`` and infers each model's
    ``ModelType`` from its config files. Entries that cannot be resolved to a
    snapshot (e.g. a ``refs``-less partial download) are filtered out.
    """
    from diffusers.utils import DIFFUSERS_CACHE

    cache_dir = Path(DIFFUSERS_CACHE)
    if not cache_dir.exists():
        return []

    def detect_model_type(snapshot_folder: Path) -> ModelType:
        # Prefer the U-Net config: its `in_channels` maps onto ModelType
        # values. Fall back to the top-level config for ControlNet models.
        unet_config = snapshot_folder / "unet" / "config.json"
        config = snapshot_folder / "config.json"
        try:
            if unet_config.exists():
                with unet_config.open("r") as f:
                    return ModelType(json.load(f)["in_channels"])
            if config.exists():
                with config.open("r") as f:
                    if json.load(f).get("_class_name") == "ControlNetModel":
                        return ModelType.CONTROL_NET
        except (OSError, ValueError, KeyError):
            # Unreadable/malformed config: treat as unknown rather than crash.
            pass
        return ModelType.UNKNOWN

    def _map_model(storage_folder: Path) -> Model | None:
        """Map one cache directory to a Model, or None when unresolvable."""
        model_type = ModelType.UNKNOWN
        if (storage_folder / "model_index.json").exists():
            # Old-style cache layout: the folder itself is the snapshot.
            model_type = detect_model_type(storage_folder)
        else:
            # Hub cache layout: refs/<revision> holds the snapshot commit hash.
            refs_folder = storage_folder / "refs"
            if not refs_folder.exists():
                # Not a valid cached model directory.
                return None
            for ref in refs_folder.iterdir():
                commit_hash = ref.read_text().strip()
                snapshot_folder = storage_folder / "snapshots" / commit_hash
                detected = detect_model_type(snapshot_folder)
                if detected != ModelType.UNKNOWN:
                    # Keep the first revision with a recognizable type; do NOT
                    # reset model_type afterwards (previous code clobbered it).
                    model_type = detected
                    break
        return Model(storage_folder.as_posix(), "", [], -1, -1, model_type)

    return [
        model
        for model in (_map_model(entry) for entry in cache_dir.iterdir() if entry.is_dir())
        if model is not None
    ]
StableDiffusionPipeline - from diffusers.utils import DIFFUSERS_CACHE, WEIGHTS_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME + from diffusers.utils import ( + DIFFUSERS_CACHE, + WEIGHTS_NAME, + CONFIG_NAME, + ONNX_WEIGHTS_NAME, + ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME try: @@ -193,12 +206,12 @@ def update(self, n=1): folder_names = [k for k in config_dict.keys() if not k.startswith("_")] allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, StableDiffusionPipeline.config_name] - except: + except Exception as e: + print(e) allow_patterns = None # make sure we don't download flax, safetensors, or ckpt weights. ignore_patterns = ["*.msgpack", "*.safetensors", "*.ckpt"] - try: _snapshot_download.snapshot_download( model,