forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into WIP
- Loading branch information
Showing
5 changed files
with
118 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import argparse | ||
import os | ||
|
||
import torch | ||
from safetensors.torch import load_file | ||
from collections import OrderedDict | ||
import json | ||
# this was just used to match the vae keys to the diffusers keys | ||
# you probably wont need this. Unless they change them.... again... again | ||
# on second thought, you probably will | ||
|
||
device = torch.device('cpu') | ||
dtype = torch.float32 | ||
|
||
parser = argparse.ArgumentParser() | ||
|
||
# require at lease one config file | ||
parser.add_argument( | ||
'file_1', | ||
nargs='+', | ||
type=str, | ||
help='Path to first safe tensor file' | ||
) | ||
|
||
parser.add_argument( | ||
'file_2', | ||
nargs='+', | ||
type=str, | ||
help='Path to second safe tensor file' | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
find_matches = False | ||
|
||
state_dict_file_1 = load_file(args.file_1[0]) | ||
state_dict_1_keys = list(state_dict_file_1.keys()) | ||
|
||
state_dict_file_2 = load_file(args.file_2[0]) | ||
state_dict_2_keys = list(state_dict_file_2.keys()) | ||
keys_in_both = [] | ||
|
||
keys_not_in_state_dict_2 = [] | ||
for key in state_dict_1_keys: | ||
if key not in state_dict_2_keys: | ||
keys_not_in_state_dict_2.append(key) | ||
|
||
keys_not_in_state_dict_1 = [] | ||
for key in state_dict_2_keys: | ||
if key not in state_dict_1_keys: | ||
keys_not_in_state_dict_1.append(key) | ||
|
||
keys_in_both = [] | ||
for key in state_dict_1_keys: | ||
if key in state_dict_2_keys: | ||
keys_in_both.append(key) | ||
|
||
# sort them | ||
keys_not_in_state_dict_2.sort() | ||
keys_not_in_state_dict_1.sort() | ||
keys_in_both.sort() | ||
|
||
|
||
json_data = { | ||
"both": keys_in_both, | ||
"state_dict_2": keys_not_in_state_dict_2, | ||
"state_dict_1": keys_not_in_state_dict_1 | ||
} | ||
json_data = json.dumps(json_data, indent=4) | ||
|
||
remaining_diffusers_values = OrderedDict() | ||
for key in keys_not_in_state_dict_1: | ||
remaining_diffusers_values[key] = state_dict_file_2[key] | ||
|
||
# print(remaining_diffusers_values.keys()) | ||
|
||
remaining_ldm_values = OrderedDict() | ||
for key in keys_not_in_state_dict_2: | ||
remaining_ldm_values[key] = state_dict_file_1[key] | ||
|
||
# print(json_data) | ||
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
json_save_path = os.path.join(project_root, 'config', 'keys.json') | ||
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') | ||
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') | ||
|
||
with open(json_save_path, 'w') as f: | ||
f.write(json_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters