diff --git a/run.py b/run.py
index 15c57a2a..e6a250be 100644
--- a/run.py
+++ b/run.py
@@ -36,6 +36,14 @@ def main():
         action='store_true',
         help='Continue running additional jobs even if a job fails'
     )
+
+    # optional name to replace the [name] tag in the config file
+    parser.add_argument(
+        '-n', '--name',
+        type=str,
+        default=None,
+        help='Name to replace [name] tag in config file, useful for shared config file'
+    )
     args = parser.parse_args()
 
     config_file_list = args.config_file_list
@@ -49,7 +57,7 @@ def main():
 
     for config_file in config_file_list:
         try:
-            job = get_job(config_file)
+            job = get_job(config_file, args.name)
             job.run()
             job.cleanup()
             jobs_completed += 1
diff --git a/testing/compare_keys.py b/testing/compare_keys.py
new file mode 100644
index 00000000..021178bd
--- /dev/null
+++ b/testing/compare_keys.py
@@ -0,0 +1,89 @@
+import argparse
+import os
+
+import torch
+from safetensors.torch import load_file
+from collections import OrderedDict
+import json
+# this was just used to match the vae keys to the diffusers keys
+# you probably won't need this. Unless they change them.... again... again
+# on second thought, you probably will
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+parser = argparse.ArgumentParser()
+
+# require the two safetensors files to compare
+parser.add_argument(
+    'file_1',
+    nargs='+',
+    type=str,
+    help='Path to first safe tensor file'
+)
+
+parser.add_argument(
+    'file_2',
+    nargs='+',
+    type=str,
+    help='Path to second safe tensor file'
+)
+
+args = parser.parse_args()
+
+find_matches = False
+
+state_dict_file_1 = load_file(args.file_1[0])
+state_dict_1_keys = list(state_dict_file_1.keys())
+
+state_dict_file_2 = load_file(args.file_2[0])
+state_dict_2_keys = list(state_dict_file_2.keys())
+keys_in_both = []
+
+keys_not_in_state_dict_2 = []
+for key in state_dict_1_keys:
+    if key not in state_dict_2_keys:
+        keys_not_in_state_dict_2.append(key)
+
+keys_not_in_state_dict_1 = []
+for key in state_dict_2_keys:
+    if key not in state_dict_1_keys:
+        keys_not_in_state_dict_1.append(key)
+
+keys_in_both = []
+for key in state_dict_1_keys:
+    if key in state_dict_2_keys:
+        keys_in_both.append(key)
+
+# sort them
+keys_not_in_state_dict_2.sort()
+keys_not_in_state_dict_1.sort()
+keys_in_both.sort()
+
+
+json_data = {
+    "both": keys_in_both,
+    "state_dict_2": keys_not_in_state_dict_2,
+    "state_dict_1": keys_not_in_state_dict_1
+}
+json_data = json.dumps(json_data, indent=4)
+
+remaining_diffusers_values = OrderedDict()
+for key in keys_not_in_state_dict_1:
+    remaining_diffusers_values[key] = state_dict_file_2[key]
+
+# print(remaining_diffusers_values.keys())
+
+remaining_ldm_values = OrderedDict()
+for key in keys_not_in_state_dict_2:
+    remaining_ldm_values[key] = state_dict_file_1[key]
+
+# print(json_data)
+
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+json_save_path = os.path.join(project_root, 'config', 'keys.json')
+json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
+json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
+
+with open(json_save_path, 'w') as f:
+    f.write(json_data)
\ No newline at end of file
diff --git a/toolkit/config.py b/toolkit/config.py
index 9d51c3be..602e04ae 100644
--- a/toolkit/config.py
+++ b/toolkit/config.py
@@ -15,22 +15,24 @@ def get_cwd_abs_path(path):
     return path
 
 
-def preprocess_config(config: OrderedDict):
+def preprocess_config(config: OrderedDict, name: str = None):
     if "job" not in config:
         raise ValueError("config file must have a job key")
     if "config" not in config:
         raise ValueError("config file must have a config section")
ValueError("config file must have a config section") - if "name" not in config["config"]: + if "name" not in config["config"] and name is None: raise ValueError("config file must have a config.name key") # we need to replace tags. For now just [name] - name = config["config"]["name"] + if name is not None: + config["config"]["name"] = name + else: + name = config["config"]["name"] config_string = json.dumps(config) config_string = config_string.replace("[name]", name) config = json.loads(config_string, object_pairs_hook=OrderedDict) return config - # Fixes issue where yaml doesnt load exponents correctly fixed_loader = yaml.SafeLoader fixed_loader.add_implicit_resolver( @@ -44,7 +46,8 @@ def preprocess_config(config: OrderedDict): |\\.(?:nan|NaN|NAN))$''', re.X), list(u'-+0123456789.')) -def get_config(config_file_path): + +def get_config(config_file_path, name=None): # first check if it is in the config folder config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) # see if it is in the config folder with any of the possible extensions if it doesnt have one @@ -75,4 +78,4 @@ def get_config(config_file_path): else: raise ValueError(f"Config file {config_file_path} must be a json or yaml file") - return preprocess_config(config) + return preprocess_config(config, name) diff --git a/toolkit/job.py b/toolkit/job.py index 62d1bcfd..da85505b 100644 --- a/toolkit/job.py +++ b/toolkit/job.py @@ -1,8 +1,8 @@ from toolkit.config import get_config -def get_job(config_path): - config = get_config(config_path) +def get_job(config_path, name=None): + config = get_config(config_path, name) if not config['job']: raise ValueError('config file is invalid. Missing "job" key') diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py index dad5aff8..af11ee9e 100644 --- a/toolkit/lycoris_utils.py +++ b/toolkit/lycoris_utils.py @@ -67,9 +67,6 @@ def extract_conv( return (extract_weight_A, extract_weight_B, diff), 'low rank' -extra_weights = ['lora_unet_conv_in.alpha', 'lora_unet_conv_in.lora_down.weight', 'lora_unet_conv_in.lora_mid.weight', 'lora_unet_conv_in.lora_up.weight', 'lora_unet_conv_out.alpha', 'lora_unet_conv_out.lora_down.weight', 'lora_unet_conv_out.lora_mid.weight', 'lora_unet_conv_out.lora_up.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 
'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_time_embedding_linear_1.alpha', 'lora_unet_time_embedding_linear_1.lora_down.weight', 'lora_unet_time_embedding_linear_1.lora_up.weight', 'lora_unet_time_embedding_linear_2.alpha', 'lora_unet_time_embedding_linear_2.lora_down.weight', 'lora_unet_time_embedding_linear_2.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_up.weight', 
'lora_unet_up_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_up.weight']
-
-
 def extract_linear(
     weight: Union[torch.Tensor, nn.Parameter],
     mode='fixed',
@@ -177,7 +174,7 @@ def make_state_dict(
         if module.__class__.__name__ in target_replace_modules:
             temp[name] = {}
             for child_name, child_module in module.named_modules():
-                if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
+                if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                     continue
                 temp[name][child_name] = child_module.weight
         elif name in target_replace_names:
@@ -190,12 +187,12 @@ def make_state_dict(
                 lora_name = prefix + '.' + name + '.' + child_name
                lora_name = lora_name.replace('.', '_')
                layer = child_module.__class__.__name__
-                if layer in {'Linear', 'Conv2d'}:
+                if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                     root_weight = child_module.weight
                     if torch.allclose(root_weight, weights[child_name]):
                         continue
 
-                if layer == 'Linear':
+                if layer == 'Linear' or layer == 'LoRACompatibleLinear':
                     weight, decompose_mode = extract_linear(
                         (child_module.weight - weights[child_name]),
                         mode,
@@ -204,7 +201,7 @@
                     )
                     if decompose_mode == 'low rank':
                         extract_a, extract_b, diff = weight
-                elif layer == 'Conv2d':
+                elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
                     is_linear = (child_module.weight.shape[2] == 1
                                  and child_module.weight.shape[3] == 1)
                     if not is_linear and linear_only:
@@ -258,12 +255,12 @@
            lora_name = lora_name.replace('.', '_')
            layer = module.__class__.__name__
 
-            if layer in {'Linear', 'Conv2d'}:
+            if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                root_weight = module.weight
                if torch.allclose(root_weight, weights):
                    continue
 
-            if layer == 'Linear':
+            if layer == 'Linear' or layer == 'LoRACompatibleLinear':
                weight, decompose_mode = extract_linear(
                    (root_weight - weights),
                    mode,
@@ -272,7 +269,7 @@
                )
                if decompose_mode == 'low rank':
                    extract_a, extract_b, diff = weight
-            elif layer == 'Conv2d':
+            elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
                is_linear = (
                    root_weight.shape[2] == 1
                    and root_weight.shape[3] == 1
@@ -493,7 +490,8 @@ def merge_state_dict(
     for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
         if module.__class__.__name__ in target_replace_modules:
             for child_name, child_module in module.named_modules():
-                if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
+                if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d',
+                                                           'LoRACompatibleConv'}:
                     continue
                 lora_name = prefix + '.' + name + '.' + child_name
                 lora_name = lora_name.replace('.', '_')
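
Usage sketch for the new --name flag (a hedged example, not part of the diff; the config file name "shared_config" and the output_path key are hypothetical placeholders). A shared config can omit config.name and reference the [name] tag anywhere in its values:

    # config/shared_config.yaml  (hypothetical example)
    job: extract
    config:
      process:
        - output_path: "output/[name].safetensors"

    # reuse the same config for two different runs
    python run.py shared_config --name run_a
    python run.py shared_config --name run_b

With --name given, preprocess_config() sets config.name from the flag and substitutes every [name] tag in the serialized config before the job is built; without the flag, behavior is unchanged and config.name must still be present in the file.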