Commit

Merge remote-tracking branch 'origin/main' into WIP
jaretburkett committed Jul 31, 2023
2 parents c01673f + c1b1e80 commit 63cacf4
Showing 5 changed files with 118 additions and 20 deletions.
10 changes: 9 additions & 1 deletion run.py
@@ -36,6 +36,14 @@ def main():
action='store_true',
help='Continue running additional jobs even if a job fails'
)

# optional name used to fill the [name] tag in the config file
parser.add_argument(
'-n', '--name',
type=str,
default=None,
help='Name to replace the [name] tag in the config file; useful for shared config files'
)
args = parser.parse_args()

config_file_list = args.config_file_list
@@ -49,7 +57,7 @@ def main():

for config_file in config_file_list:
try:
job = get_job(config_file)
job = get_job(config_file, args.name)
job.run()
job.cleanup()
jobs_completed += 1
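The new --name flag is threaded through get_job into the config loader, so one config file can be shared across runs. A hypothetical invocation (both the config name and the run name are placeholders):

    python run.py my_shared_config --name my_lora_v1
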
89 changes: 89 additions & 0 deletions testing/compare_keys.py
@@ -0,0 +1,89 @@
import argparse
import os

import torch
from safetensors.torch import load_file
from collections import OrderedDict
import json
# this was just used to match the vae keys to the diffusers keys
# you probably won't need this. Unless they change them... again... again
# on second thought, you probably will

device = torch.device('cpu')
dtype = torch.float32

parser = argparse.ArgumentParser()

# the two safetensors files to compare
parser.add_argument(
'file_1',
nargs='+',
type=str,
help='Path to the first safetensors file'
)

parser.add_argument(
'file_2',
nargs='+',
type=str,
help='Path to the second safetensors file'
)

args = parser.parse_args()

find_matches = False

state_dict_file_1 = load_file(args.file_1[0])
state_dict_1_keys = list(state_dict_file_1.keys())

state_dict_file_2 = load_file(args.file_2[0])
state_dict_2_keys = list(state_dict_file_2.keys())
keys_in_both = []

keys_not_in_state_dict_2 = []
for key in state_dict_1_keys:
if key not in state_dict_2_keys:
keys_not_in_state_dict_2.append(key)

keys_not_in_state_dict_1 = []
for key in state_dict_2_keys:
if key not in state_dict_1_keys:
keys_not_in_state_dict_1.append(key)

keys_in_both = []
for key in state_dict_1_keys:
if key in state_dict_2_keys:
keys_in_both.append(key)

# sort them
keys_not_in_state_dict_2.sort()
keys_not_in_state_dict_1.sort()
keys_in_both.sort()


json_data = {
"both": keys_in_both,
"state_dict_2": keys_not_in_state_dict_2,
"state_dict_1": keys_not_in_state_dict_1
}
json_data = json.dumps(json_data, indent=4)

remaining_diffusers_values = OrderedDict()
for key in keys_not_in_state_dict_1:
remaining_diffusers_values[key] = state_dict_file_2[key]

# print(remaining_diffusers_values.keys())

remaining_ldm_values = OrderedDict()
for key in keys_not_in_state_dict_2:
remaining_ldm_values[key] = state_dict_file_1[key]

# print(json_data)

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')

with open(json_save_path, 'w') as f:
f.write(json_data)
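
A hypothetical run (file names are placeholders) would be: python testing/compare_keys.py vae_ldm.safetensors vae_diffusers.safetensors. It writes config/keys.json, where "both" lists shared keys, "state_dict_2" lists keys missing from the second file, and "state_dict_1" lists keys missing from the first. A minimal sketch of inspecting that output, assuming the script has already been run:

import json

with open('config/keys.json') as f:
    groups = json.load(f)

print(len(groups['both']), 'keys present in both files')
print(len(groups['state_dict_2']), 'keys only in the first file')
print(len(groups['state_dict_1']), 'keys only in the second file')
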
15 changes: 9 additions & 6 deletions toolkit/config.py
@@ -15,22 +15,24 @@ def get_cwd_abs_path(path):
return path


def preprocess_config(config: OrderedDict):
def preprocess_config(config: OrderedDict, name: str = None):
if "job" not in config:
raise ValueError("config file must have a job key")
if "config" not in config:
raise ValueError("config file must have a config section")
if "name" not in config["config"]:
if "name" not in config["config"] and name is None:
raise ValueError("config file must have a config.name key")
# we need to replace tags. For now just [name]
name = config["config"]["name"]
if name is not None:
config["config"]["name"] = name
else:
name = config["config"]["name"]
config_string = json.dumps(config)
config_string = config_string.replace("[name]", name)
config = json.loads(config_string, object_pairs_hook=OrderedDict)
return config



# Fixes issue where yaml doesn't load exponents correctly
fixed_loader = yaml.SafeLoader
fixed_loader.add_implicit_resolver(
@@ -44,7 +46,8 @@ def preprocess_config(config: OrderedDict):
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.'))

def get_config(config_file_path):

def get_config(config_file_path, name=None):
# first check if it is in the config folder
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
# see if it is in the config folder with any of the possible extensions if it doesn't have one
@@ -75,4 +78,4 @@ def get_config(config_file_path):
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")

return preprocess_config(config)
return preprocess_config(config, name)
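
A minimal sketch of the new [name] substitution, using a purely hypothetical config (the job type and keys below are placeholders, not from this repo):

from collections import OrderedDict
from toolkit.config import preprocess_config

config = OrderedDict({
    'job': 'extension',  # hypothetical job type; preprocess_config only checks the key exists
    'config': OrderedDict({
        'name': '[name]',               # filled from the --name argument
        'output_dir': 'output/[name]',  # every [name] occurrence is replaced
    }),
})

resolved = preprocess_config(config, name='my_lora_v1')
assert resolved['config']['name'] == 'my_lora_v1'
assert resolved['config']['output_dir'] == 'output/my_lora_v1'
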
4 changes: 2 additions & 2 deletions toolkit/job.py
@@ -1,8 +1,8 @@
from toolkit.config import get_config


def get_job(config_path):
config = get_config(config_path)
def get_job(config_path, name=None):
config = get_config(config_path, name)
if not config['job']:
raise ValueError('config file is invalid. Missing "job" key')

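The same override is available programmatically through get_job; a short sketch, with the config path and name as placeholders:

from toolkit.job import get_job

job = get_job('my_shared_config', name='my_lora_v1')  # placeholders
job.run()
job.cleanup()
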
20 changes: 9 additions & 11 deletions toolkit/lycoris_utils.py
@@ -67,9 +67,6 @@ def extract_conv(
return (extract_weight_A, extract_weight_B, diff), 'low rank'


extra_weights = ['lora_unet_conv_in.alpha', 'lora_unet_conv_in.lora_down.weight', 'lora_unet_conv_in.lora_mid.weight', 'lora_unet_conv_in.lora_up.weight', 'lora_unet_conv_out.alpha', 'lora_unet_conv_out.lora_down.weight', 'lora_unet_conv_out.lora_mid.weight', 'lora_unet_conv_out.lora_up.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_time_embedding_linear_1.alpha', 'lora_unet_time_embedding_linear_1.lora_down.weight', 'lora_unet_time_embedding_linear_1.lora_up.weight', 'lora_unet_time_embedding_linear_2.alpha', 'lora_unet_time_embedding_linear_2.lora_down.weight', 'lora_unet_time_embedding_linear_2.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 
'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_down.weight', 
'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_up.weight']


def extract_linear(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
@@ -177,7 +174,7 @@ def make_state_dict(
if module.__class__.__name__ in target_replace_modules:
temp[name] = {}
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
continue
temp[name][child_name] = child_module.weight
elif name in target_replace_names:
@@ -190,12 +187,12 @@ def make_state_dict(
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
layer = child_module.__class__.__name__
if layer in {'Linear', 'Conv2d'}:
if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
root_weight = child_module.weight
if torch.allclose(root_weight, weights[child_name]):
continue

if layer == 'Linear':
if layer == 'Linear' or layer == 'LoRACompatibleLinear':
weight, decompose_mode = extract_linear(
(child_module.weight - weights[child_name]),
mode,
@@ -204,7 +201,7 @@ def make_state_dict(
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
elif layer == 'Conv2d':
elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
is_linear = (child_module.weight.shape[2] == 1
and child_module.weight.shape[3] == 1)
if not is_linear and linear_only:
@@ -258,12 +255,12 @@ def make_state_dict(
lora_name = lora_name.replace('.', '_')
layer = module.__class__.__name__

if layer in {'Linear', 'Conv2d'}:
if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
root_weight = module.weight
if torch.allclose(root_weight, weights):
continue

if layer == 'Linear':
if layer == 'Linear' or layer == 'LoRACompatibleLinear':
weight, decompose_mode = extract_linear(
(root_weight - weights),
mode,
@@ -272,7 +269,7 @@ def make_state_dict(
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
elif layer == 'Conv2d':
elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
is_linear = (
root_weight.shape[2] == 1
and root_weight.shape[3] == 1
@@ -493,7 +490,8 @@ def merge_state_dict(
for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d',
'LoRACompatibleConv'}:
continue
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
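The lycoris_utils.py hunks all widen the same class-name checks so that diffusers' LoRACompatibleLinear and LoRACompatibleConv wrappers are handled like plain Linear and Conv2d layers. A possible consolidation, not part of this commit, would keep the accepted names in module-level sets so a future wrapper type only needs one edit:

# Sketch only; these names are not defined in the commit.
LINEAR_LAYERS = {'Linear', 'LoRACompatibleLinear'}
CONV_LAYERS = {'Conv2d', 'LoRACompatibleConv'}
SUPPORTED_LAYERS = LINEAR_LAYERS | CONV_LAYERS

def is_supported(module) -> bool:
    # True for any layer type the extraction/merge code knows how to handle
    return module.__class__.__name__ in SUPPORTED_LAYERS
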
