diff --git a/README.md b/README.md
index 55f02c12..25e6db17 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,8 @@ pip3 install -r requirements.txt
 I have so many hodge-podge scripts from my ML work that I am going to be moving over to this. But this is
 what is here so far.
 
+---
+
 ### LoRA (lierla), LoCON (LyCORIS) extractor
 
 It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features.
@@ -64,6 +66,31 @@ Most people used fixed, which is traditional fixed dimension extraction.
 `process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc.
 
+---
+
+### LoRA Rescale
+
+A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
+It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yaml`.
+Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
+Then you can edit the file to your liking and call it like so:
+
+```bash
+python3 run.py config/whatever_you_want.yml
+```
+
+You can also give a full path to a config file if you want to keep it somewhere else.
+
+```bash
+python3 run.py "/home/user/whatever_you_want.yml"
+```
+
+More notes on how it works are available in the example config file itself. This is useful when making
+any LoRA, as the ideal weight is rarely 1.0, but now you can fix that. Sliders can have weird scales, from -2 to 2
+or even -15 to 15. This lets you dial them in so they all have your desired scale.
+Change `` to `` or whatever you want, with the same effect.
+
+---
 
 ### LoRA Slider Trainer
 
@@ -108,13 +135,32 @@ Just went in and out. It is much worse on smaller faces than shown here.
 ## TODO
 - [X] Add proper regs on sliders
-- [ ] Add SDXL support (base model only for now)
+- [X] Add SDXL support (base model only for now)
 - [ ] Add plain erasing
 - [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
 
 ---
 
 ## Change Log
+
+#### 2023-08-01
+Major changes and updates. New LoRA rescale tool; see above for details. Added better metadata so
+Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still
+unstable at the moment, so hopefully these are not breaking changes.
+
+Unfortunately, I am too lazy to write a proper changelog with all the changes.
+
+I added SDXL training to sliders... but it does not work properly.
+The slider training relies on a model's ability to understand that an unconditional (negative prompt)
+means you do not want that concept in the output. SDXL does not understand this for whatever reason,
+which makes separating out concepts within the model hard. I am sure the community will find a way to fix this
+over time, but for now, it is not going to work properly. And if any of you are thinking, "Could we maybe fix it
+by adding 1 or 2 more text encoders to the model as well as a few more entirely separate diffusion networks?"
+No. God no. It just needs a little training without every experimental new paper added to it. The KISS principle.
+
+
 #### 2023-07-30
 Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a regularizer.
 You can set the network multiplier to force spread consistency at high weights.
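The anchor mechanic described in the 2023-07-30 entry can be pictured as an extra regularization term. The sketch below is a conceptual illustration only, not the trainer's actual code: it assumes an anchor prompt is kept stable by penalizing the slider network, applied at a high multiplier, for changing the model's prediction on that prompt. The function and argument names are made up for illustration.

```python
import torch
import torch.nn.functional as F


def anchor_regularizer_loss(
    pred_with_network: torch.Tensor,     # prediction for the anchor prompt, slider network active at a high multiplier
    pred_without_network: torch.Tensor,  # prediction for the anchor prompt with the network multiplier at 0
    anchor_weight: float = 1.0,
) -> torch.Tensor:
    # Penalize drift on the anchor prompt so the slider stays consistent even at high weights.
    return anchor_weight * F.mse_loss(pred_with_network, pred_without_network)
```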
diff --git a/config/examples/mod_lora_scale.yaml b/config/examples/mod_lora_scale.yaml
new file mode 100644
index 00000000..5f59ecc8
--- /dev/null
+++ b/config/examples/mod_lora_scale.yaml
@@ -0,0 +1,48 @@
+---
+job: mod
+config:
+  name: name_of_your_model_v1
+  process:
+    - type: rescale_lora
+      # path to your current lora model
+      input_path: "/path/to/lora/lora.safetensors"
+      # output path for your new lora model, can be the same as input_path to replace it
+      output_path: "/path/to/lora/output_lora_v1.safetensors"
+      # replaces meta with the meta below (plus minimum meta fields)
+      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
+      replace_meta: true
+      # how to adjust: we can scale the up_down weights or the alpha
+      # up_down is the default and probably the best; they will both net the same outputs
+      # the choice would only affect rare NaN cases and maybe merging with old merge tools
+      scale_target: 'up_down'
+      # precision to save, fp16 is the default and standard
+      save_dtype: fp16
+      # current_weight is the weight you currently use as a multiplier when using the lora
+      # IE in automatic1111 the 6.0 is the current_weight
+      # you can do negatives here too if you want to flip the lora
+      current_weight: 6.0
+      # target_weight is the weight you want to use as a multiplier when using the lora
+      # instead of the one above. IE in automatic1111, instead of using
+      # we want to use so 1.0 is the target_weight
+      target_weight: 1.0
+
+      # base model for the lora
+      # this is just used to add meta so automatic1111 knows which model it is for
+      # assumes v1.5 if these are not set
+      is_xl: false
+      is_v2: false
+meta:
+  # this is only used if you set replace_meta to true above
+  name: "[name]" # [name] gets replaced with the name above
+  description: A short description of your lora
+  trigger_words:
+    - put
+    - trigger
+    - words
+    - here
+  version: '0.1'
+  creator:
+    name: Your Name
+    email: your@email.com
+    website: https://yourwebsite.com
+  any: All metadata above is arbitrary; it can be whatever you want.
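To make `current_weight` and `target_weight` concrete: rescaling adjusts the LoRA so that applying it at `target_weight` reproduces what you previously got at `current_weight`. Here is a minimal sketch of the arithmetic, assuming the usual LoRA contribution of `lora_up(lora_down(x)) * multiplier * (alpha / lora_dim)`, the same formula quoted in the comments of `ModRescaleLoraProcess` further down:

```python
import math

current_weight = 6.0  # the multiplier you use today (as in the example config above)
target_weight = 1.0   # the multiplier you want to use instead

# Overall factor the LoRA's output must grow by so target_weight nets the old result.
total_module_scale = current_weight / target_weight   # 6.0

# scale_target: 'alpha'   -> multiply alpha by total_module_scale directly.
# scale_target: 'up_down' -> split the factor across the two matrices, because scaling
#                            both lora_up and lora_down by s scales their product by s**2.
up_down_scale = total_module_scale ** (1.0 / 2)        # sqrt(6.0)

assert math.isclose((up_down_scale ** 2) * target_weight, current_weight)
```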
diff --git a/info.py b/info.py index 2e3c824c..85113e00 100644 --- a/info.py +++ b/info.py @@ -3,6 +3,6 @@ v = OrderedDict() v["name"] = "ai-toolkit" v["repo"] = "https://github.com/ostris/ai-toolkit" -v["version"] = "0.0.1" +v["version"] = "0.0.2" software_meta = v diff --git a/jobs/ModJob.py b/jobs/ModJob.py new file mode 100644 index 00000000..e37990de --- /dev/null +++ b/jobs/ModJob.py @@ -0,0 +1,28 @@ +import os +from collections import OrderedDict +from jobs import BaseJob +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype + +process_dict = { + 'rescale_lora': 'ModRescaleLoraProcess', +} + + +class ModJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/__init__.py b/jobs/__init__.py index 688ccfc1..f00a20dc 100644 --- a/jobs/__init__.py +++ b/jobs/__init__.py @@ -2,3 +2,4 @@ from .ExtractJob import ExtractJob from .TrainJob import TrainJob from .MergeJob import MergeJob +from .ModJob import ModJob diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8a7e891c..9a7c8302 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -19,7 +19,7 @@ DDIMScheduler, DDPMScheduler from jobs.process import BaseTrainProcess -from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta from toolkit.train_tools import get_torch_dtype, apply_noise_offset import gc @@ -192,6 +192,7 @@ def sample(self, step=None, is_first=False): num_inference_steps=sample_config.sample_steps, guidance_scale=sample_config.guidance_scale, negative_prompt=neg, + guidance_rescale=0.7, ).images[0] else: img = pipeline( @@ -236,21 +237,26 @@ def sample(self, step=None, is_first=False): # self.sd.tokenizer.to(original_device_dict['tokenizer']) def update_training_metadata(self): - dict = OrderedDict({ + o_dict = OrderedDict({ "training_info": self.get_training_info() }) if self.model_config.is_v2: - dict['ss_v2'] = True - dict['ss_base_model_version'] = 'sd_2.1' + o_dict['ss_v2'] = True + o_dict['ss_base_model_version'] = 'sd_2.1' elif self.model_config.is_xl: - dict['ss_base_model_version'] = 'sdxl_1.0' + o_dict['ss_base_model_version'] = 'sdxl_1.0' else: - dict['ss_base_model_version'] = 'sd_1.5' + o_dict['ss_base_model_version'] = 'sd_1.5' - dict['ss_output_name'] = self.job.name + o_dict = add_base_model_info_to_meta( + o_dict, + is_v2=self.model_config.is_v2, + is_xl=self.model_config.is_xl, + ) + o_dict['ss_output_name'] = self.job.name - self.add_meta(dict) + self.add_meta(o_dict) def get_training_info(self): info = OrderedDict({ @@ -381,7 +387,7 @@ def predict_noise( text_embeddings: PromptEmbeds, timestep: int, guidance_scale=7.5, - guidance_rescale=0.7, + guidance_rescale=0, # 0.7 add_time_ids=None, **kwargs, ): @@ -389,7 +395,6 @@ def predict_noise( if self.sd.is_xl: if add_time_ids is None: add_time_ids = self.get_time_ids_from_latents(latents) - # todo LECOs code looks like it is omitting noise_pred latent_model_input = torch.cat([latents] * 2) @@ -500,13 +505,17 @@ def run(self): dtype = 
get_torch_dtype(self.train_config.dtype) # TODO handle other schedulers - sch = KDPM2DiscreteScheduler + # sch = KDPM2DiscreteScheduler + sch = DDPMScheduler # do our own scheduler + prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" scheduler = sch( num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", + clip_sample=False, + prediction_type=prediction_type, ) if self.model_config.is_xl: if self.custom_pipeline is not None: diff --git a/jobs/process/ModRescaleLoraProcess.py b/jobs/process/ModRescaleLoraProcess.py new file mode 100644 index 00000000..882ef0e0 --- /dev/null +++ b/jobs/process/ModRescaleLoraProcess.py @@ -0,0 +1,100 @@ +import gc +import os +from collections import OrderedDict +from typing import ForwardRef + +import torch +from safetensors.torch import save_file, load_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ + add_base_model_info_to_meta +from toolkit.train_tools import get_torch_dtype + + +class ModRescaleLoraProcess(BaseProcess): + process_id: int + config: OrderedDict + progress_bar: ForwardRef('tqdm') = None + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.input_path = self.get_conf('input_path', required=True) + self.output_path = self.get_conf('output_path', required=True) + self.replace_meta = self.get_conf('replace_meta', default=False) + self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype) + self.current_weight = self.get_conf('current_weight', required=True, as_type=float) + self.target_weight = self.get_conf('target_weight', required=True, as_type=float) + self.scale_target = self.get_conf('scale_target', default='up_down') # alpha or up_down + self.is_xl = self.get_conf('is_xl', default=False, as_type=bool) + self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool) + + self.progress_bar = None + + def run(self): + super().run() + source_state_dict = load_file(self.input_path) + source_meta = load_metadata_from_safetensors(self.input_path) + + if self.replace_meta: + self.meta.update( + add_base_model_info_to_meta( + self.meta, + is_xl=self.is_xl, + is_v2=self.is_v2, + ) + ) + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + else: + save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + new_state_dict = OrderedDict() + + for key in list(source_state_dict.keys()): + v = source_state_dict[key] + v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32')) + + # all loras have an alpha, up weight and down weight + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha", + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight", + # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight", + # we can rescale by adjusting the alpha or the up weights, or the up and down weights + # I assume doing both up and down would be best all around, but I'm not sure + # some locons also have mid weights, we will leave those alone for now, will work without them + + # when adjusting alpha, it is used to calculate the multiplier in a lora module + # - scale = alpha / lora_dim + # - output = layer_out + lora_up_out * multiplier * scale + total_module_scale = torch.tensor(self.current_weight / self.target_weight) \ + 
.to("cpu", dtype=get_torch_dtype('fp32')) + num_modules_layers = 2 # up and down + up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ + .to("cpu", dtype=get_torch_dtype('fp32')) + # only update alpha + if self.scale_target == 'alpha' and key.endswith('.alpha'): + v = v * total_module_scale + if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): + # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN + v = v * up_down_scale + new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype)) + + save_meta = add_model_hash_to_meta(new_state_dict, save_meta) + save_file(new_state_dict, self.output_path, save_meta) + + # cleanup incase there are other jobs + del new_state_dict + del source_state_dict + del source_meta + + torch.cuda.empty_cache() + gc.collect() + + print(f"Saved to {self.output_path}") diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 0913eebf..e618d9c9 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -46,8 +46,8 @@ def __init__( negative_target, negative_target_with_neutral, neutral, - both_targets, empty_prompt, + both_targets, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, multiplier=1.0, weight=1.0 @@ -123,22 +123,23 @@ def hook_before_train_loop(self): self.print(f"Loading prompt file from {self.slider_config.prompt_file}") # read line by line from file - with open(self.slider_config.prompt_file, 'r') as f: - self.prompt_txt_list = f.readlines() - # clean empty lines - self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + if self.slider_config.prompt_file: + with open(self.slider_config.prompt_file, 'r') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] - self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") + self.print(f"Loaded {len(self.prompt_txt_list)} prompts. 
Encoding them..") - cache = PromptEmbedsCache() - if not self.slider_config.prompt_tensors: - # shuffle - random.shuffle(self.prompt_txt_list) - # trim to max steps - self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] - # trim list to our max steps + if not self.slider_config.prompt_tensors: + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + cache = PromptEmbedsCache() # get encoded latents for our prompts with torch.no_grad(): @@ -169,7 +170,9 @@ def hook_before_train_loop(self): # encode empty_prompt cache[empty_prompt] = self.sd.encode_prompt(empty_prompt) - for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False): + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): for target in self.slider_config.targets: prompt_list = [ f"{target.target_class}", # target_class @@ -212,10 +215,15 @@ def hook_before_train_loop(self): save_file(state_dict, self.slider_config.prompt_tensors) prompt_pairs = [] - for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False): + for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): for target in self.slider_config.targets: + erase_negative = len(target.positive.strip()) == 0 + enhance_positive = len(target.negative.strip()) == 0 + + both = not erase_negative and not enhance_positive if both or erase_negative: + print("Encoding erase negative") prompt_pairs += [ # erase standard EncodedPromptPair( @@ -234,6 +242,7 @@ def hook_before_train_loop(self): ), ] if both or enhance_positive: + print("Encoding enhance positive") prompt_pairs += [ # enhance standard, swap pos neg EncodedPromptPair( @@ -251,7 +260,9 @@ def hook_before_train_loop(self): weight=target.weight ), ] - if both or enhance_positive: + # if both or enhance_positive: + if both: + print("Encoding erase positive (inverse)") prompt_pairs += [ # erase inverted EncodedPromptPair( @@ -269,7 +280,9 @@ def hook_before_train_loop(self): weight=target.weight ), ] - if both or erase_negative: + # if both or erase_negative: + if both: + print("Encoding enhance negative (inverse)") prompt_pairs += [ # enhance inverted EncodedPromptPair( @@ -341,10 +354,6 @@ def hook_train_loop(self): torch.randint(0, len(self.slider_config.resolutions), (1,)).item() ] - target_class = prompt_pair.target_class - neutral = prompt_pair.neutral - negative = prompt_pair.negative_target - positive = prompt_pair.positive_target weight = prompt_pair.weight multiplier = prompt_pair.multiplier diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index e4fc21bc..e58e0069 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -8,4 +8,5 @@ from .TrainSliderProcess import TrainSliderProcess from .TrainSliderProcessOld import TrainSliderProcessOld from .TrainLoRAHack import TrainLoRAHack -from .TrainSDRescaleProcess import TrainSDRescaleProcess \ No newline at end of file +from .TrainSDRescaleProcess import TrainSDRescaleProcess +from .ModRescaleLoraProcess import ModRescaleLoraProcess diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 03bf3487..ef8f3c95 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -99,5 +99,5 @@ def __init__(self, **kwargs): anchors = [SliderConfigAnchors(**anchor) for anchor in anchors] self.anchors: List[SliderConfigAnchors] = 
anchors self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) - self.prompt_file: str = kwargs.get('prompt_file', '') - self.prompt_tensors: str = kwargs.get('prompt_tensors', '') + self.prompt_file: str = kwargs.get('prompt_file', None) + self.prompt_tensors: str = kwargs.get('prompt_tensors', None) diff --git a/toolkit/job.py b/toolkit/job.py index da85505b..c0b1a191 100644 --- a/toolkit/job.py +++ b/toolkit/job.py @@ -13,6 +13,9 @@ def get_job(config_path, name=None): if job == 'train': from jobs import TrainJob return TrainJob(config) + if job == 'mod': + from jobs import ModJob + return ModJob(config) # elif job == 'train': # from jobs import TrainJob diff --git a/toolkit/lora.py b/toolkit/lora.py index 9b3b65a6..0780cbc3 100644 --- a/toolkit/lora.py +++ b/toolkit/lora.py @@ -6,12 +6,14 @@ import os import math from typing import Optional, List, Type, Set, Literal +from collections import OrderedDict import torch import torch.nn as nn from diffusers import UNet2DConditionModel from safetensors.torch import save_file +from toolkit.metadata import add_model_hash_to_meta UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2 @@ -31,7 +33,7 @@ "innoxattn", # train all layers except self attention layers "selfattn", # ESD-u, train only self attention layers "xattn", # ESD-x, train only x attention layers - "full", # train all layers + "full", # train all layers # "notime", # "xlayer", # "outxattn", @@ -48,12 +50,12 @@ class LoRAModule(nn.Module): """ def __init__( - self, - lora_name, - org_module: nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() @@ -102,19 +104,19 @@ def apply_to(self): def forward(self, x): return ( - self.org_forward(x) - + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + self.org_forward(x) + + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale ) class LoRANetwork(nn.Module): def __init__( - self, - unet: UNet2DConditionModel, - rank: int = 4, - multiplier: float = 1.0, - alpha: float = 1.0, - train_method: TRAINING_METHODS = "full", + self, + unet: UNet2DConditionModel, + rank: int = 4, + multiplier: float = 1.0, + alpha: float = 1.0, + train_method: TRAINING_METHODS = "full", ) -> None: super().__init__() @@ -140,7 +142,7 @@ def __init__( lora_names = set() for lora in self.unet_loras: assert ( - lora.lora_name not in lora_names + lora.lora_name not in lora_names ), f"duplicated lora name: {lora.lora_name}. 
{lora_names}" lora_names.add(lora.lora_name) @@ -157,13 +159,13 @@ def __init__( torch.cuda.empty_cache() def create_modules( - self, - prefix: str, - root_module: nn.Module, - target_replace_modules: List[str], - rank: int, - multiplier: float, - train_method: TRAINING_METHODS, + self, + prefix: str, + root_module: nn.Module, + target_replace_modules: List[str], + rank: int, + multiplier: float, + train_method: TRAINING_METHODS, ) -> list: loras = [] @@ -212,6 +214,8 @@ def prepare_optimizer_params(self): def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): state_dict = self.state_dict() + if metadata is None: + metadata = OrderedDict() if dtype is not None: for key in list(state_dict.keys()): @@ -221,9 +225,10 @@ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): for key in list(state_dict.keys()): if not key.startswith("lora"): - # lora以外除外 + # remove any not lora del state_dict[key] + metadata = add_model_hash_to_meta(state_dict, metadata) if os.path.splitext(file)[1] == ".safetensors": save_file(state_dict, file, metadata) else: diff --git a/toolkit/metadata.py b/toolkit/metadata.py index 6605feb3..b652dcef 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -1,18 +1,23 @@ import json from collections import OrderedDict +from io import BytesIO +import safetensors from safetensors import safe_open from info import software_meta +from toolkit.train_tools import addnet_hash_legacy +from toolkit.train_tools import addnet_hash_safetensors -def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict: +def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict: # stringify the meta and reparse OrderedDict to replace [name] with name meta_string = json.dumps(meta) if name is not None: meta_string = meta_string.replace("[name]", name) save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict) - save_meta["software"] = software_meta + if add_software_info: + save_meta["software"] = software_meta # safetensors can only be one level deep for key, value in save_meta.items(): # if not float, int, bool, or str, convert to json string @@ -21,6 +26,46 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict: return save_meta +def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict: + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in meta.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(state_dict, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + meta["sshs_model_hash"] = model_hash + meta["sshs_legacy_hash"] = legacy_hash + return meta + + +def add_base_model_info_to_meta( + meta: OrderedDict, + base_model: str = None, + is_v1: bool = False, + is_v2: bool = False, + is_xl: bool = False, +) -> OrderedDict: + if base_model is not None: + meta['ss_base_model'] = base_model + elif is_v2: + meta['ss_v2'] = True + meta['ss_base_model_version'] = 'sd_2.1' + + elif is_xl: + meta['ss_base_model_version'] = 'sdxl_1.0' + else: + # default to v1.5 + meta['ss_base_model_version'] = 'sd_1.5' + return meta + + def parse_metadata_from_safetensors(meta: OrderedDict) -> 
OrderedDict: parsed_meta = OrderedDict() for key, value in meta.items(): diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index 09ebb734..96aabfca 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -54,6 +54,8 @@ def get_optimizer( elif lower_type == 'lion': from lion_pytorch import Lion return Lion(params, lr=learning_rate, **optimizer_params) + elif lower_type == 'adagrad': + optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params) else: raise ValueError(f'Unknown optimizer type {optimizer_type}') return optimizer diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 9ac9f31c..3be84dca 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -1,4 +1,5 @@ import argparse +import hashlib import json import os import time @@ -399,3 +400,29 @@ def concat_prompt_embeddings( [unconditional.pooled_embeds, conditional.pooled_embeds] ).repeat_interleave(n_imgs, dim=0) return PromptEmbeds([text_embeds, pooled_embeds]) + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8]
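For reference, a hypothetical use of the two hash helpers above (the file path is made up): this mirrors what `add_model_hash_to_meta` in `toolkit/metadata.py` does, but starting from a `.safetensors` file already on disk.

```python
from io import BytesIO

from toolkit.train_tools import addnet_hash_legacy, addnet_hash_safetensors

# hypothetical path to a saved LoRA
with open("/path/to/lora/output_lora_v1.safetensors", "rb") as f:
    buf = BytesIO(f.read())

# These are the values stored as sshs_model_hash / sshs_legacy_hash in the saved metadata.
print("sshs_model_hash:", addnet_hash_safetensors(buf))
print("sshs_legacy_hash:", addnet_hash_legacy(buf))
```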