From d64f1bd5c9fdd0259ef4136c75b30a19699a89be Mon Sep 17 00:00:00 2001 From: bkosowski Date: Sun, 15 Sep 2024 13:18:32 +0200 Subject: [PATCH] Allow loading of .sft VAE files for Flux --- backend/utils.py | 2 +- modules/sd_vae.py | 13 +++++++------ modules_forge/main_entry.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/backend/utils.py b/backend/utils.py index c88fceae2..d1ed780c0 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -22,7 +22,7 @@ def read_arbitrary_config(directory): def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") - if ckpt.lower().endswith(".safetensors"): + if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): sd = safetensors.torch.load_file(ckpt, device=device.type) elif ckpt.lower().endswith(".gguf"): reader = gguf.GGUFReader(ckpt) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index fdcab34dd..bd779be1a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,13 +1,10 @@ -import os import collections -from dataclasses import dataclass - -from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes - import glob +import os from copy import deepcopy -from backend.utils import load_torch_file +from dataclasses import dataclass +from modules import paths, shared, sd_models, extra_networks, hashes vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} @@ -77,9 +74,11 @@ def refresh_vae_list(): os.path.join(sd_models.model_path, '**/*.vae.ckpt'), os.path.join(sd_models.model_path, '**/*.vae.pt'), os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(sd_models.model_path, '**/*.vae.sft'), os.path.join(vae_path, '**/*.ckpt'), os.path.join(vae_path, '**/*.pt'), os.path.join(vae_path, '**/*.safetensors'), + os.path.join(vae_path, '**/*.sft'), ] if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): @@ -87,6 +86,7 @@ def refresh_vae_list(): os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.sft'), ] if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir): @@ -94,6 +94,7 @@ def refresh_vae_list(): os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'), os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'), os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.sft'), ] candidates = [] diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index 964f16f50..2ed94923b 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -142,7 +142,7 @@ def refresh_models(): shared_items.refresh_checkpoints() ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short) - file_extensions = ['ckpt', 'pt', 'bin', 'safetensors', 'gguf'] + vae_file_extensions = ['ckpt', 'pt', 'bin', 'safetensors', 'gguf', 'sft'] module_list.clear() @@ -157,7 +157,7 @@ def refresh_models(): module_paths.append(os.path.abspath(shared.cmd_opts.text_encoder_dir)) for vae_path in module_paths: - vae_files = find_files_with_extensions(vae_path, file_extensions) + vae_files = find_files_with_extensions(vae_path, vae_file_extensions) module_list.update(vae_files) return ckpt_list, module_list.keys()