forked from ostris/ai-toolkit
Added Model rescale and prepared a release upgrade
commit 8b8d538 (1 parent: 63cacf4)
Showing 15 changed files with 387 additions and 63 deletions.
@@ -0,0 +1,48 @@
---
job: mod
config:
  name: name_of_your_model_v1
  process:
    - type: rescale_lora
      # path to your current lora model
      input_path: "/path/to/lora/lora.safetensors"
      # output path for your new lora model; can be the same as input_path to replace it
      output_path: "/path/to/lora/output_lora_v1.safetensors"
      # replaces the meta with the meta below (plus minimum meta fields)
      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
      replace_meta: true
      # what to adjust: we can scale the up/down weights or the alpha
      # up_down is the default and probably the best; both net the same outputs
      # the choice only matters for rare NaN cases and maybe merging with old merge tools
      scale_target: 'up_down'
      # precision to save; fp16 is the default and standard
      save_dtype: fp16
      # current_weight is the weight you currently use as a multiplier for the lora,
      # i.e. in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
      # you can use negatives here too if you want to flip the lora
      current_weight: 6.0
      # target_weight is the weight you want to use as a multiplier instead of the
      # one above, i.e. in automatic1111, instead of using <lora:my_lora:6.0>
      # we want to use <lora:my_lora:1.0>, so 1.0 is the target_weight
      target_weight: 1.0

      # base model for the lora
      # this is just used to add meta so automatic1111 knows which model it is for
      # v1.5 is assumed if these are not set
      is_xl: false
      is_v2: false

meta:
  # this is only used if you set replace_meta to true above
  name: "[name]"  # [name] gets replaced with the name above
  description: A short description of your lora
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: [email protected]
    website: https://yourwebsite.com
  any: All meta data above is arbitrary; it can be whatever you want.
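
For intuition on the two weight fields: the job bakes the ratio current_weight / target_weight into the saved tensors, splitting that factor evenly across each module's up and down matrices. A quick sketch of the arithmetic with the numbers above (plain Python, mirroring the scaling done in ModRescaleLoraProcess further down in this commit):

    current_weight = 6.0  # strength you use today, e.g. <lora:my_lora:6.0>
    target_weight = 1.0   # strength you want to use after rescaling

    total_module_scale = current_weight / target_weight  # 6.0 gets baked into the lora
    # each lora module has two matrices (up and down), so each one
    # is multiplied by the square root of the total scale
    up_down_scale = total_module_scale ** 0.5

    print(total_module_scale)       # 6.0
    print(round(up_down_scale, 3))  # 2.449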
@@ -0,0 +1,28 @@
import os
from collections import OrderedDict
from jobs import BaseJob
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype

process_dict = {
    'rescale_lora': 'ModRescaleLoraProcess',
}


class ModJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.device = self.get_conf('device', 'cpu')

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        super().run()

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()
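
One detail worth noting in process_dict: the values are class-name strings rather than imported classes, presumably so BaseJob.load_processes can resolve them lazily (avoiding circular imports). A generic sketch of that string-to-class resolution pattern (illustrative only; the repo's actual helper lives in BaseJob and may differ):

    import importlib

    def resolve_process_class(type_name: str, process_dict: dict):
        # map the configured process type, e.g. 'rescale_lora',
        # to its class name, then import that class by name
        class_name = process_dict[type_name]  # 'ModRescaleLoraProcess'
        module = importlib.import_module('jobs.process')  # assumed package path
        return getattr(module, class_name)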
@@ -0,0 +1,100 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef

import torch
from safetensors.torch import save_file, load_file

from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
    add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype


class ModRescaleLoraProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        self.replace_meta = self.get_conf('replace_meta', default=False)
        self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
        self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
        self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
        self.scale_target = self.get_conf('scale_target', default='up_down')  # alpha or up_down
        self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
        self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)

        self.progress_bar = None
    def run(self):
        super().run()
        source_state_dict = load_file(self.input_path)
        source_meta = load_metadata_from_safetensors(self.input_path)

        if self.replace_meta:
            self.meta.update(
                add_base_model_info_to_meta(
                    self.meta,
                    is_xl=self.is_xl,
                    is_v2=self.is_v2,
                )
            )
            save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        else:
            save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)

        # make sure the output folder exists before saving
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

        new_state_dict = OrderedDict()

        for key in list(source_state_dict.keys()):
            v = source_state_dict[key]
            v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))

            # every lora module has an alpha, an up weight, and a down weight, e.g.
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha"
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight"
            # we can rescale by adjusting the alpha, or the up and down weights
            # I assume doing both up and down would be best all around, but I'm not sure
            # some locons also have mid weights; we leave those alone for now, it works without them

            # when adjusting alpha, it is used to calculate the multiplier in a lora module
            # - scale = alpha / lora_dim
            # - output = layer_out + lora_up_out * multiplier * scale
            total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
                .to("cpu", dtype=get_torch_dtype('fp32'))
            num_modules_layers = 2  # up and down
            up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
                .to("cpu", dtype=get_torch_dtype('fp32'))
            # only update the alpha
            if self.scale_target == 'alpha' and key.endswith('.alpha'):
                v = v * total_module_scale
            # scale both the up and down weights
            if self.scale_target == 'up_down' and (key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight')):
                # would it be better to adjust only the up weights for fp16 precision?
                # doing both should reduce the chance of NaN
                v = v * up_down_scale
            new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))

        save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
        save_file(new_state_dict, self.output_path, save_meta)

        # cleanup in case there are other jobs
        del new_state_dict
        del source_state_dict
        del source_meta

        torch.cuda.empty_cache()
        gc.collect()

        print(f"Saved to {self.output_path}")