From 8c90fa86c68b7f67a7853737fccb0ff6c491e689 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 5 Aug 2023 18:46:08 -0600 Subject: [PATCH] Complete reqork of how slider training works and optimized it to hell. Can run entire algorythm in 1 batch now with less VRAM consumption than a quarter of it used to take --- README.md | 15 +- config/examples/train_slider.example.yml | 104 +++-- info.py | 2 +- jobs/process/BaseSDTrainProcess.py | 9 + jobs/process/TrainSliderProcess.py | 525 ++++++++++------------- toolkit/config_modules.py | 1 + toolkit/layers.py | 13 + toolkit/lora_special.py | 203 ++++++++- toolkit/prompt_utils.py | 387 +++++++++++++++++ toolkit/stable_diffusion_model.py | 60 ++- 10 files changed, 942 insertions(+), 377 deletions(-) create mode 100644 toolkit/prompt_utils.py diff --git a/README.md b/README.md index 0885e733..9ed568a3 100644 --- a/README.md +++ b/README.md @@ -170,18 +170,27 @@ Just went in and out. It is much worse on smaller faces than shown here. ## Change Log +#### 2023-08-05 + - Huge memory rework and slider rework. Slider training is better thant ever with no more +ram spikes. I also made it so all 4 parts of the slider algorythm run in one batch so they share gradient +accumulation. This makes it much faster and more stable. + - Updated the example config to be something more practical and more updated to current methods. It is now +a detail slide and shows how to train one without a subject. 512x512 slider training for 1.5 should work on +6GB gpu now. Will test soon to verify. + + #### 2021-10-20 - Windows support bug fixes - Extensions! Added functionality to make and share custom extensions for training, merging, whatever. check out the example in the `extensions` folder. Read more about that above. - Model Merging, provided via the example extension. -#### 2021-08-03 +#### 2023-08-03 Another big refactor to make SD more modular. Made batch image generation script -#### 2021-08-01 +#### 2023-08-01 Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable at the moment, so hopefully there are not breaking changes. @@ -199,7 +208,7 @@ encoders to the model as well as a few more entirely separate diffusion networks training without every experimental new paper added to it. The KISS principal. -#### 2021-07-30 +#### 2023-07-30 Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a regularizer. You can set the network multiplier to force spread consistency at high weights diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml index be796a09..c92955ec 100644 --- a/config/examples/train_slider.example.yml +++ b/config/examples/train_slider.example.yml @@ -7,7 +7,7 @@ job: train config: # the name will be used to create a folder in the output folder # it will also replace any [name] token in the rest of this config - name: pet_slider_v1 + name: detail_slider_v1 # folder will be created with name above in folder below # it can be relative to the project root or absolute training_folder: "output/LoRA" @@ -24,7 +24,7 @@ config: type: "lierla" # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good rank: 8 - alpha: 1.0 # just leave it + alpha: 4 # Do about half of rank # training config train: @@ -33,7 +33,7 @@ config: # how many steps to train. More is not always better. I rarely go over 1000 steps: 500 # I have had good results with 4e-4 to 1e-4 at 500 steps - lr: 1e-4 + lr: 2e-4 # enables gradient checkpoint, saves vram, leave it on gradient_checkpointing: true # train the unet. I recommend leaving this true @@ -43,6 +43,7 @@ config: # not the description of it (text encoder) train_text_encoder: false + # just leave unless you know what you are doing # also supports "dadaptation" but set lr to 1 if you use that, # but it learns too fast and I don't recommend it @@ -53,6 +54,7 @@ config: # while training. Just leave it max_denoising_steps: 40 # works great at 1. I do 1 even with my 4090. + # higher may not work right with newer single batch stacking code anyway batch_size: 1 # bf16 works best if your GPU supports it (modern) dtype: bf16 # fp32, bf16, fp16 @@ -69,12 +71,17 @@ config: name_or_path: "runwayml/stable-diffusion-v1-5" is_v2: false # for v2 models is_v_pred: false # for v-prediction models (most v2 models) + # has some issues with the dual text encoder and the way we train sliders + # it works bit weights need to probably be higher to see it. is_xl: false # for SDXL models # saving config save: dtype: float16 # precision to save. I recommend float16 save_every: 50 # save every this many steps + # this will remove step counts more than this number + # allows you to save more often in case of a crash without filling up your drive + max_step_saves_to_keep: 2 # sampling config sample: @@ -92,21 +99,22 @@ config: # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive # slide are good tests. will inherit sample.network_multiplier if not set # --n [string] # negative prompt, will inherit sample.neg if not set - # Only 75 tokens allowed currently - prompts: # our example is an animal slider, neg: dog, pos: cat - - "a golden retriever --m -5" - - "a golden retriever --m -3" - - "a golden retriever --m 3" - - "a golden retriever --m 5" - - "calico cat --m -5" - - "calico cat --m -3" - - "calico cat --m 3" - - "calico cat --m 5" - - "an elephant --m -5" - - "an elephant --m -3" - - "an elephant --m 3" - - "an elephant --m 5" + # I like to do a wide positive and negative spread so I can see a good range and stop + # early if the network is braking down + prompts: + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5" + - "a golden retriever sitting on a leather couch, --m -5" + - "a golden retriever sitting on a leather couch --m -3" + - "a golden retriever sitting on a leather couch --m 3" + - "a golden retriever sitting on a leather couch --m 5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5" # negative prompt used on all prompts above as default if they don't have one neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome" # seed for sampling. 42 is the answer for everything @@ -135,11 +143,16 @@ config: # resolutions to train on. [ width, height ]. This is less important for sliders # as we are not teaching the model anything it doesn't already know # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1 + # and [ 1024, 1024 ] for sd_xl # you can do as many as you want here resolutions: - [ 512, 512 ] # - [ 512, 768 ] # - [ 768, 768 ] + # slider training uses 4 combined steps for a single round. This will do it in one gradient + # step. It is highly optimized and shouldn't take anymore vram than doing without it, + # since we break down batches for gradient accumulation now. so just leave it on. + batch_full_slide: true # These are the concepts to train on. You can do as many as you want here, # but they can conflict outweigh each other. Other than experimenting, I recommend # just doing one for good results @@ -150,7 +163,9 @@ config: # a keyword necessarily but what the model understands the concept to represent. # "person" will affect men, women, children, etc but will not affect cats, dogs, etc # it is the models base general understanding of the concept and everything it represents - - target_class: "animal" + # you can leave it blank to affect everything. In this example, we are adjusting + # detail, so we will leave it blank to affect everything + - target_class: "" # positive is the prompt for the positive side of the slider. # It is the concept that will be excited and amplified in the model when we slide the slider # to the positive side and forgotten / inverted when we slide @@ -158,33 +173,44 @@ config: # the prompt. You want it to be the extreme of what you want to train on. For example, # if you want to train on fat people, you would use "an extremely fat, morbidly obese person" # as the prompt. Not just "fat person" - positive: "cat" + # max 75 tokens for now + positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # negative is the prompt for the negative side of the slider and works the same as positive # it does not necessarily work the same as a negative prompt when generating images - negative: "dog" + # these need to be polar opposites. + # max 76 tokens for now + negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # the loss for this target is multiplied by this number. # if you are doing more than one target it may be good to set less important ones - # to a lower number like 0.1 so they dont outweigh the primary target + # to a lower number like 0.1 so they don't outweigh the primary target weight: 1.0 - # anchors are prompts that wer try to hold on to while training the slider - # you want these to generate an image very similar to the target_class - # without directly overlapping it. For example, if you are training on a person smiling, - # you would use "a person with a face mask" as an anchor. It is a person, the image is the same - # regardless if they are smiling or not - anchors: - # only positive prompt for now - - prompt: "a woman" - neg_prompt: "animal" - # the multiplier applied to the LoRA when this is run. - # higher will give it more weight but also help keep the lora from collapsing - multiplier: 8.0 - - prompt: "a man" - neg_prompt: "animal" - multiplier: 8.0 - - prompt: "a person" - neg_prompt: "animal" - multiplier: 8.0 + + # anchors are prompts that we will try to hold on to while training the slider + # these are NOT necessary and can prevent the slider from converging if not done right + # leave them off if you are having issues, but they can help lock the network + # on certain concepts to help prevent catastrophic forgetting + # you want these to generate an image that is not your target_class, but close to it + # is fine as long as it does not directly overlap it. + # For example, if you are training on a person smiling, + # you could use "a person with a face mask" as an anchor. It is a person, the image is the same + # regardless if they are smiling or not, however, the closer the concept is to the target_class + # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually + # for close concepts, you want to be closer to 0.1 or 0.2 + # these will slow down training. I am leaving them off for the demo + +# anchors: +# - prompt: "a woman" +# neg_prompt: "animal" +# # the multiplier applied to the LoRA when this is run. +# # higher will give it more weight but also help keep the lora from collapsing +# multiplier: 1.0 +# - prompt: "a man" +# neg_prompt: "animal" +# multiplier: 1.0 +# - prompt: "a person" +# neg_prompt: "animal" +# multiplier: 1.0 # You can put any information you want here, and it will be saved in the model. # The below is an example, but you can put your grocery list in it if you want. diff --git a/info.py b/info.py index b40d9344..81cffcad 100644 --- a/info.py +++ b/info.py @@ -3,6 +3,6 @@ v = OrderedDict() v["name"] = "ai-toolkit" v["repo"] = "https://github.com/ostris/ai-toolkit" -v["version"] = "0.0.3" +v["version"] = "0.0.4" software_meta = v diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index afde7db7..58bfba02 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -242,6 +242,12 @@ def run(self): unet.enable_xformers_memory_efficient_attention() if self.train_config.gradient_checkpointing: unet.enable_gradient_checkpointing() + # if isinstance(text_encoder, list): + # for te in text_encoder: + # te.enable_gradient_checkpointing() + # else: + # text_encoder.enable_gradient_checkpointing() + unet.to(self.device_torch, dtype=dtype) unet.requires_grad_(False) unet.eval() @@ -281,6 +287,9 @@ def run(self): default_lr=self.train_config.lr ) + if self.train_config.gradient_checkpointing: + self.network.enable_gradient_checkpointing() + latest_save_path = self.get_latest_save_path() if latest_save_path is not None: self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index ecc99542..e10ec5b2 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -3,12 +3,14 @@ import random from collections import OrderedDict import os -from typing import Optional +from typing import Optional, Union from safetensors.torch import save_file, load_file +import torch.utils.checkpoint as cp from tqdm import tqdm from toolkit.config_modules import SliderConfig +from toolkit.layers import CheckpointGradients from toolkit.paths import REPOS_ROOT import sys @@ -16,88 +18,21 @@ from toolkit.train_tools import get_torch_dtype import gc from toolkit import train_tools +from toolkit.prompt_utils import \ + EncodedPromptPair, ACTION_TYPES_SLIDER, \ + EncodedAnchor, concat_prompt_pairs, \ + concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \ + split_prompt_pairs import torch from .BaseSDTrainProcess import BaseSDTrainProcess -class ACTION_TYPES_SLIDER: - ERASE_NEGATIVE = 0 - ENHANCE_NEGATIVE = 1 - - def flush(): torch.cuda.empty_cache() gc.collect() -class EncodedPromptPair: - def __init__( - self, - target_class, - target_class_with_neutral, - positive_target, - positive_target_with_neutral, - negative_target, - negative_target_with_neutral, - neutral, - empty_prompt, - both_targets, - action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=1.0, - weight=1.0 - ): - self.target_class = target_class - self.target_class_with_neutral = target_class_with_neutral - self.positive_target = positive_target - self.positive_target_with_neutral = positive_target_with_neutral - self.negative_target = negative_target - self.negative_target_with_neutral = negative_target_with_neutral - self.neutral = neutral - self.empty_prompt = empty_prompt - self.both_targets = both_targets - self.multiplier = multiplier - self.action: int = action - self.weight = weight - - # simulate torch to for tensors - def to(self, *args, **kwargs): - self.target_class = self.target_class.to(*args, **kwargs) - self.positive_target = self.positive_target.to(*args, **kwargs) - self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) - self.negative_target = self.negative_target.to(*args, **kwargs) - self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs) - self.neutral = self.neutral.to(*args, **kwargs) - self.empty_prompt = self.empty_prompt.to(*args, **kwargs) - self.both_targets = self.both_targets.to(*args, **kwargs) - return self - - -class PromptEmbedsCache: - prompts: dict[str, PromptEmbeds] = {} - - def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: - self.prompts[__name] = __value - - def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: - if __name in self.prompts: - return self.prompts[__name] - else: - return None - - -class EncodedAnchor: - def __init__( - self, - prompt, - neg_prompt, - multiplier=1.0 - ): - self.prompt = prompt - self.neg_prompt = neg_prompt - self.multiplier = multiplier - - class TrainSliderProcess(BaseSDTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) @@ -110,6 +45,8 @@ def __init__(self, process_id: int, job, config: OrderedDict): self.prompt_cache = PromptEmbedsCache() self.prompt_pairs: list[EncodedPromptPair] = [] self.anchor_pairs: list[EncodedAnchor] = [] + # keep track of prompt chunk size + self.prompt_chunk_size = 1 def before_model_load(self): pass @@ -137,163 +74,57 @@ def hook_before_train_loop(self): # get encoded latents for our prompts with torch.no_grad(): - if self.slider_config.prompt_tensors is not None: - # check to see if it exists - if os.path.exists(self.slider_config.prompt_tensors): - # load it. - self.print(f"Loading prompt tensors from {self.slider_config.prompt_tensors}") - prompt_tensors = load_file(self.slider_config.prompt_tensors, device='cpu') - # add them to the cache - for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False): - if prompt_txt.startswith("te:"): - prompt = prompt_txt[3:] - # text_embeds - text_embeds = prompt_tensor - pooled_embeds = None - # find pool embeds - if f"pe:{prompt}" in prompt_tensors: - pooled_embeds = prompt_tensors[f"pe:{prompt}"] - - # make it - prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) - cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) - - if len(cache.prompts) == 0: - print("Prompt tensors not found. Encoding prompts..") - empty_prompt = "" - # encode empty_prompt - cache[empty_prompt] = self.sd.encode_prompt(empty_prompt) - - neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] - - for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): - for target in self.slider_config.targets: - prompt_list = [ - f"{target.target_class}", # target_class - f"{target.target_class} {neutral}", # target_class with neutral - f"{target.positive}", # positive_target - f"{target.positive} {neutral}", # positive_target with neutral - f"{target.negative}", # negative_target - f"{target.negative} {neutral}", # negative_target with neutral - f"{neutral}", # neutral - f"{target.positive} {target.negative}", # both targets - f"{target.negative} {target.positive}", # both targets - ] - for p in prompt_list: - # build the cache - if cache[p] is None: - cache[p] = self.sd.encode_prompt(p).to(device="cpu", dtype=torch.float32) - - erase_negative = len(target.positive.strip()) == 0 - enhance_positive = len(target.negative.strip()) == 0 - - both = not erase_negative and not enhance_positive - - if erase_negative and enhance_positive: - raise ValueError("target must have at least one of positive or negative or both") - # for slider we need to have an enhancer, an eraser, and then - # an inverse with negative weights to balance the network - # if we don't do this, we will get different contrast and focus. - # we only perform actions of enhancing and erasing on the negative - # todo work on way to do all of this in one shot - if self.slider_config.prompt_tensors: - print(f"Saving prompt tensors to {self.slider_config.prompt_tensors}") - state_dict = {} - for prompt_txt, prompt_embeds in cache.prompts.items(): - state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", - dtype=get_torch_dtype('fp16')) - if prompt_embeds.pooled_embeds is not None: - state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", - dtype=get_torch_dtype( - 'fp16')) - save_file(state_dict, self.slider_config.prompt_tensors) + # list of neutrals. Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) prompt_pairs = [] - for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): for target in self.slider_config.targets: - erase_negative = len(target.positive.strip()) == 0 - enhance_positive = len(target.negative.strip()) == 0 - - both = not erase_negative and not enhance_positive - - if both or erase_negative: - print("Encoding erase negative") - prompt_pairs += [ - # erase standard - EncodedPromptPair( - target_class=cache[target.target_class], - target_class_with_neutral=cache[f"{target.target_class} {neutral}"], - positive_target=cache[f"{target.positive}"], - positive_target_with_neutral=cache[f"{target.positive} {neutral}"], - negative_target=cache[f"{target.negative}"], - negative_target_with_neutral=cache[f"{target.negative} {neutral}"], - neutral=cache[neutral], - action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=target.multiplier, - both_targets=cache[f"{target.positive} {target.negative}"], - empty_prompt=cache[""], - weight=target.weight - ), - ] - if both or enhance_positive: - print("Encoding enhance positive") - prompt_pairs += [ - # enhance standard, swap pos neg - EncodedPromptPair( - target_class=cache[target.target_class], - target_class_with_neutral=cache[f"{target.target_class} {neutral}"], - positive_target=cache[f"{target.negative}"], - positive_target_with_neutral=cache[f"{target.negative} {neutral}"], - negative_target=cache[f"{target.positive}"], - negative_target_with_neutral=cache[f"{target.positive} {neutral}"], - neutral=cache[neutral], - action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - multiplier=target.multiplier, - both_targets=cache[f"{target.positive} {target.negative}"], - empty_prompt=cache[""], - weight=target.weight - ), - ] - # if both or enhance_positive: - if both: - print("Encoding erase positive (inverse)") - prompt_pairs += [ - # erase inverted - EncodedPromptPair( - target_class=cache[target.target_class], - target_class_with_neutral=cache[f"{target.target_class} {neutral}"], - positive_target=cache[f"{target.negative}"], - positive_target_with_neutral=cache[f"{target.negative} {neutral}"], - negative_target=cache[f"{target.positive}"], - negative_target_with_neutral=cache[f"{target.positive} {neutral}"], - neutral=cache[neutral], - action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - both_targets=cache[f"{target.positive} {target.negative}"], - empty_prompt=cache[""], - multiplier=target.multiplier * -1.0, - weight=target.weight - ), - ] - # if both or erase_negative: - if both: - print("Encoding enhance negative (inverse)") - prompt_pairs += [ - # enhance inverted - EncodedPromptPair( - target_class=cache[target.target_class], - target_class_with_neutral=cache[f"{target.target_class} {neutral}"], - positive_target=cache[f"{target.positive}"], - positive_target_with_neutral=cache[f"{target.positive} {neutral}"], - negative_target=cache[f"{target.negative}"], - negative_target_with_neutral=cache[f"{target.negative} {neutral}"], - both_targets=cache[f"{target.positive} {target.negative}"], - neutral=cache[neutral], - action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - empty_prompt=cache[""], - multiplier=target.multiplier * -1.0, - weight=target.weight - ), - ] + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += [x.to('cpu') for x in prompt_pair_batch] # setup anchors anchor_pairs = [] @@ -306,13 +137,26 @@ def hook_before_train_loop(self): if cache[prompt] == None: cache[prompt] = self.sd.encode_prompt(prompt) + anchor_batch = [] + # we get the prompt pair multiplier from first prompt pair + # since they are all the same. We need to match their network polarity + prompt_pair_multipliers = prompt_pairs[0].multiplier_list + for prompt_multiplier in prompt_pair_multipliers: + # match the network multiplier polarity + anchor_scalar = 1.0 if prompt_multiplier > 0 else -1.0 + anchor_batch += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier * anchor_scalar + ) + ] + anchor_pairs += [ - EncodedAnchor( - prompt=cache[anchor.prompt], - neg_prompt=cache[anchor.neg_prompt], - multiplier=anchor.multiplier - ) + concat_anchors(anchor_batch).to('cpu') ] + if len(anchor_pairs) > 0: + self.anchor_pairs = anchor_pairs # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling @@ -324,17 +168,13 @@ def hook_before_train_loop(self): self.sd.text_encoder.to("cpu") self.prompt_cache = cache self.prompt_pairs = prompt_pairs - self.anchor_pairs = anchor_pairs + # self.anchor_pairs = anchor_pairs flush() # end hook_before_train_loop def hook_train_loop(self): dtype = get_torch_dtype(self.train_config.dtype) - # get random multiplier between 1 and 3 - rand_weight = 1 - # rand_weight = torch.rand((1,)).item() * 2 + 1 - # get a random pair prompt_pair: EncodedPromptPair = self.prompt_pairs[ torch.randint(0, len(self.prompt_pairs), (1,)).item() @@ -346,11 +186,10 @@ def hook_train_loop(self): height, width = self.slider_config.resolutions[ torch.randint(0, len(self.slider_config.resolutions), (1,)).item() ] + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() - weight = prompt_pair.weight - multiplier = prompt_pair.multiplier - - unet = self.sd.unet noise_scheduler = self.sd.noise_scheduler optimizer = self.optimizer lr_scheduler = self.lr_scheduler @@ -368,9 +207,6 @@ def get_noise_pred(neg, pos, gs, cts, dn): guidance_scale=gs, ) - # set network multiplier - self.network.multiplier = multiplier * rand_weight - with torch.no_grad(): self.sd.noise_scheduler.set_timesteps( self.train_config.max_denoising_steps, device=self.device_torch @@ -383,11 +219,14 @@ def get_noise_pred(neg, pos, gs, cts, dn): 1, self.train_config.max_denoising_steps, (1,) ).item() + # for a complete slider, the batch size is 4 to begin with now + true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size + # get noise noise = self.sd.get_latent_noise( pixel_height=height, pixel_width=width, - batch_size=self.train_config.batch_size, + batch_size=true_batch_size, noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) @@ -397,7 +236,8 @@ def get_noise_pred(neg, pos, gs, cts, dn): with self.network: assert self.network.is_active - self.network.multiplier = multiplier * rand_weight + # pass the multiplier list to the network + self.network.multiplier = prompt_pair.multiplier_list denoised_latents = self.sd.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( @@ -410,19 +250,27 @@ def get_noise_pred(neg, pos, gs, cts, dn): guidance_scale=3, ) + # split the latents into out prompt pair chunks + denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0) + noise_scheduler.set_timesteps(1000) current_timestep = noise_scheduler.timesteps[ int(timesteps_to * 1000 / self.train_config.max_denoising_steps) ] + # flush() # 4.2GB to 3GB on 512x512 + + # 4.20 GB RAM for 512x512 positive_latents = get_noise_pred( prompt_pair.positive_target, # negative prompt prompt_pair.negative_target, # positive prompt 1, current_timestep, denoised_latents - ).to("cpu", dtype=torch.float32) + ) + positive_latents.requires_grad = False + positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0) neutral_latents = get_noise_pred( prompt_pair.positive_target, # negative prompt @@ -430,7 +278,9 @@ def get_noise_pred(neg, pos, gs, cts, dn): 1, current_timestep, denoised_latents - ).to("cpu", dtype=torch.float32) + ) + neutral_latents.requires_grad = False + neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0) unconditional_latents = get_noise_pred( prompt_pair.positive_target, # negative prompt @@ -438,87 +288,142 @@ def get_noise_pred(neg, pos, gs, cts, dn): 1, current_timestep, denoised_latents - ).to("cpu", dtype=torch.float32) + ) + unconditional_latents.requires_grad = False + unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0) + + flush() # 4.2GB to 3GB on 512x512 - anchor_loss = None + # 4.20 GB RAM for 512x512 + anchor_loss_float = None if len(self.anchor_pairs) > 0: - # get a random anchor pair - anchor: EncodedAnchor = self.anchor_pairs[ - torch.randint(0, len(self.anchor_pairs), (1,)).item() - ] with torch.no_grad(): - anchor_target_noise = get_noise_pred( - anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents - ).to("cpu", dtype=torch.float32) - with self.network: - # anchor whatever weight prompt pair is using - pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 - self.network.multiplier = anchor.multiplier * pos_nem_mult * rand_weight + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + anchor.to(self.device_torch, dtype=dtype) - anchor_pred_noise = get_noise_pred( - anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents - ).to("cpu", dtype=torch.float32) + # first we get the target prediction without network active + anchor_target_noise = get_noise_pred( + anchor.neg_prompt, anchor.prompt, 1, current_timestep, denoised_latents + # ).to("cpu", dtype=torch.float32) + ).requires_grad_(False) - self.network.multiplier = prompt_pair.multiplier * rand_weight + # to save vram, we will run these through separately while tracking grads + # otherwise it consumes a ton of vram and this isn't our speed bottleneck + anchor_chunks = split_anchors(anchor, self.prompt_chunk_size) + anchor_target_noise_chunks = torch.chunk(anchor_target_noise, self.prompt_chunk_size, dim=0) + assert len(anchor_chunks) == len(denoised_latent_chunks) + # 4.32 GB RAM for 512x512 + with self.network: + assert self.network.is_active + anchor_float_losses = [] + for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip( + anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks + ): + self.network.multiplier = anchor_chunk.multiplier_list + + anchor_pred_noise = get_noise_pred( + anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk + ) + # 9.42 GB RAM for 512x512 -> 4.20 GB RAM for 512x512 with new grad_checkpointing + anchor_loss = loss_function( + anchor_target_noise_chunk, + anchor_pred_noise, + ) + anchor_float_losses.append(anchor_loss.item()) + # compute anchor loss gradients + # we will accumulate them later + # this saves a ton of memory doing them separately + anchor_loss.backward() + del anchor_pred_noise + del anchor_target_noise_chunk + del anchor_loss + flush() + + anchor_loss_float = sum(anchor_float_losses) / len(anchor_float_losses) + del anchor_chunks + del anchor_target_noise_chunks + del anchor_target_noise + # move anchor back to cpu + anchor.to("cpu") + flush() + + prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size) + assert len(prompt_pair_chunks) == len(denoised_latent_chunks) + # 3.28 GB RAM for 512x512 with self.network: - self.network.multiplier = prompt_pair.multiplier * rand_weight - target_latents = get_noise_pred( - prompt_pair.positive_target, - prompt_pair.target_class, - 1, - current_timestep, - denoised_latents - ).to("cpu", dtype=torch.float32) - - # if self.logging_config.verbose: - # self.print("target_latents:", target_latents[0, 0, :5, :5]) - - positive_latents.requires_grad = False - neutral_latents.requires_grad = False - unconditional_latents.requires_grad = False - if len(self.anchor_pairs) > 0: - anchor_target_noise.requires_grad = False - anchor_loss = loss_function( - anchor_target_noise, - anchor_pred_noise, - ) - erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE - guidance_scale = 1.0 + assert self.network.is_active + loss_list = [] + for prompt_pair_chunk, \ + denoised_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + unconditional_latents_chunk \ + in zip( + prompt_pair_chunks, + denoised_latent_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + target_latents = get_noise_pred( + prompt_pair_chunk.positive_target, + prompt_pair_chunk.target_class, + 1, + current_timestep, + denoised_latent_chunk + ) - offset = guidance_scale * (positive_latents - unconditional_latents) + guidance_scale = 1.0 - offset_neutral = neutral_latents - if erase: - offset_neutral -= offset - else: - # enhance - offset_neutral += offset + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) - loss = loss_function( - target_latents, - offset_neutral, - ) * weight + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] - loss_slide = loss.item() + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1) + offset *= offset_multiplier - if anchor_loss is not None: - loss += anchor_loss + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset - loss_float = loss.item() + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = loss_function( + target_latents, + offset_neutral, + ) * prompt_pair_chunk.weight - loss = loss.to(self.device_torch) + loss.backward() + loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + flush() - loss.backward() optimizer.step() lr_scheduler.step() + loss_float = sum(loss_list) / len(loss_list) + if anchor_loss_float is not None: + loss_float += anchor_loss_float + del ( positive_latents, neutral_latents, unconditional_latents, - target_latents, - latents, + latents ) # move back to cpu prompt_pair.to("cpu") @@ -530,9 +435,9 @@ def get_noise_pred(neg, pos, gs, cts, dn): loss_dict = OrderedDict( {'loss': loss_float}, ) - if anchor_loss is not None: - loss_dict['sl_l'] = loss_slide - loss_dict['an_l'] = anchor_loss.item() + if anchor_loss_float is not None: + loss_dict['sl_l'] = loss_float + loss_dict['an_l'] = anchor_loss_float return loss_dict # end hook_train_loop diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 16861140..f84273d9 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -108,6 +108,7 @@ def __init__(self, **kwargs): self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) self.prompt_file: str = kwargs.get('prompt_file', None) self.prompt_tensors: str = kwargs.get('prompt_tensors', None) + self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) class GenerateImageConfig: diff --git a/toolkit/layers.py b/toolkit/layers.py index 2d6aaecb..dfc975bf 100644 --- a/toolkit/layers.py +++ b/toolkit/layers.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import numpy as np +from torch.utils.checkpoint import checkpoint class ReductionKernel(nn.Module): @@ -29,3 +30,15 @@ def build_kernel(self): def forward(self, x): return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1) + + +class CheckpointGradients(nn.Module): + def __init__(self, is_gradient_checkpointing=True): + super(CheckpointGradients, self).__init__() + self.is_gradient_checkpointing = is_gradient_checkpointing + + def forward(self, module, *args, num_chunks=1): + if self.is_gradient_checkpointing: + return checkpoint(module, *args, num_chunks=self.num_chunks) + else: + return module(*args) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 5922d2cc..35c4223a 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -1,4 +1,6 @@ +import math import os +import re import sys from typing import List, Optional, Dict, Type, Union @@ -9,7 +11,170 @@ sys.path.append(SD_SCRIPTS_ROOT) -from networks.lora import LoRANetwork, LoRAModule, get_block_index +from networks.lora import LoRANetwork, get_block_index + +from torch.utils.checkpoint import checkpoint + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier: Union[float, List[float]] = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + # this allows us to set different multipliers on a per item in a batch basis + # allowing us to run positive and negative weights in the same batch + # really only useful for slider training for now + def get_multiplier(self, lora_up): + batch_size = lora_up.size(0) + # batch will have all negative prompts first and positive prompts second + # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts + # if there is more than our multiplier, it is liekly a batch size increase, so we need to + # interleve the multipliers + if isinstance(self.multiplier, list): + if len(self.multiplier) == 0: + # single item, just return it + return self.multiplier[0] + else: + # we have a list of multipliers, so we need to get the multiplier for this batch + multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype) + # should be 1 for if total batch size was 1 + num_interleaves = (batch_size // 2) // len(self.multiplier) + multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves) + + # match lora_up rank + if len(lora_up.size()) == 2: + multiplier_tensor = multiplier_tensor.view(-1, 1) + elif len(lora_up.size()) == 3: + multiplier_tensor = multiplier_tensor.view(-1, 1, 1) + elif len(lora_up.size()) == 4: + multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1) + return multiplier_tensor + + else: + return self.multiplier + + def _call_forward(self, x): + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return 0.0 # added to original forward + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + multiplier = self.get_multiplier(lx) + + return lx * multiplier * scale + + def create_custom_forward(self): + def custom_forward(*inputs): + return self._call_forward(*inputs) + + return custom_forward + + def forward(self, x): + org_forwarded = self.org_forward(x) + # TODO this just loses the grad. Not sure why. Probably why no one else is doing it either + # if torch.is_grad_enabled() and self.is_checkpointing and self.training: + # lora_output = checkpoint( + # self.create_custom_forward(), + # x, + # ) + # else: + # lora_output = self._call_forward(x) + + lora_output = self._call_forward(x) + + return org_forwarded + lora_output + + def enable_gradient_checkpointing(self): + self.is_checkpointing = True + + def disable_gradient_checkpointing(self): + self.is_checkpointing = False class LoRASpecialNetwork(LoRANetwork): @@ -70,6 +235,7 @@ def __init__( self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.is_checkpointing = False if modules_dim is not None: print(f"create LoRA network from weights") @@ -236,14 +402,11 @@ def save_weights(self, file, dtype, metadata): torch.save(state_dict, file) @property - def multiplier(self): + def multiplier(self) -> Union[float, List[float]]: return self._multiplier @multiplier.setter - def multiplier(self, value): - # only update if changed - if self._multiplier == value: - return + def multiplier(self, value: Union[float, List[float]]): self._multiplier = value self._update_lora_multiplier() @@ -264,6 +427,8 @@ def _update_lora_multiplier(self): for lora in self.text_encoder_loras: lora.multiplier = 0 + # called when the context manager is entered + # ie: with network: def __enter__(self): self.is_active = True self._update_lora_multiplier() @@ -281,3 +446,29 @@ def force_to(self, device, dtype): loras += self.text_encoder_loras for lora in loras: lora.to(device, dtype) + + def _update_checkpointing(self): + if self.is_checkpointing: + if hasattr(self, 'unet_loras'): + for lora in self.unet_loras: + lora.enable_gradient_checkpointing() + if hasattr(self, 'text_encoder_loras'): + for lora in self.text_encoder_loras: + lora.enable_gradient_checkpointing() + else: + if hasattr(self, 'unet_loras'): + for lora in self.unet_loras: + lora.disable_gradient_checkpointing() + if hasattr(self, 'text_encoder_loras'): + for lora in self.text_encoder_loras: + lora.disable_gradient_checkpointing() + + def enable_gradient_checkpointing(self): + # not supported + self.is_checkpointing = True + self._update_checkpointing() + + def disable_gradient_checkpointing(self): + # not supported + self.is_checkpointing = False + self._update_checkpointing() diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py new file mode 100644 index 00000000..215d150e --- /dev/null +++ b/toolkit/prompt_utils.py @@ -0,0 +1,387 @@ +import os +from typing import Optional, TYPE_CHECKING, List + +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +from toolkit.stable_diffusion_model import PromptEmbeds +from toolkit.train_tools import get_torch_dtype + + +class ACTION_TYPES_SLIDER: + ERASE_NEGATIVE = 0 + ENHANCE_NEGATIVE = 1 + + +class EncodedPromptPair: + def __init__( + self, + target_class, + target_class_with_neutral, + positive_target, + positive_target_with_neutral, + negative_target, + negative_target_with_neutral, + neutral, + empty_prompt, + both_targets, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + action_list=None, + multiplier=1.0, + multiplier_list=None, + weight=1.0 + ): + self.target_class: PromptEmbeds = target_class + self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral + self.positive_target: PromptEmbeds = positive_target + self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral + self.negative_target: PromptEmbeds = negative_target + self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral + self.neutral: PromptEmbeds = neutral + self.empty_prompt: PromptEmbeds = empty_prompt + self.both_targets: PromptEmbeds = both_targets + self.multiplier: float = multiplier + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + self.action: int = action + if action_list is not None: + self.action_list: list[int] = action_list + else: + self.action_list: list[int] = [action] + self.weight: float = weight + + # simulate torch to for tensors + def to(self, *args, **kwargs): + self.target_class = self.target_class.to(*args, **kwargs) + self.positive_target = self.positive_target.to(*args, **kwargs) + self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) + self.negative_target = self.negative_target.to(*args, **kwargs) + self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs) + self.neutral = self.neutral.to(*args, **kwargs) + self.empty_prompt = self.empty_prompt.to(*args, **kwargs) + self.both_targets = self.both_targets.to(*args, **kwargs) + return self + + +def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]): + text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0) + pooled_embeds = None + if prompt_embeds[0].pooled_embeds is not None: + pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0) + return PromptEmbeds([text_embeds, pooled_embeds]) + + +def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]): + weight = prompt_pairs[0].weight + target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs]) + target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs]) + positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs]) + positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs]) + negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs]) + negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs]) + neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs]) + empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs]) + both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs]) + # combine all the lists + action_list = [] + multiplier_list = [] + weight_list = [] + for p in prompt_pairs: + action_list += p.action_list + multiplier_list += p.multiplier_list + return EncodedPromptPair( + target_class=target_class, + target_class_with_neutral=target_class_with_neutral, + positive_target=positive_target, + positive_target_with_neutral=positive_target_with_neutral, + negative_target=negative_target, + negative_target_with_neutral=negative_target_with_neutral, + neutral=neutral, + empty_prompt=empty_prompt, + both_targets=both_targets, + action_list=action_list, + multiplier_list=multiplier_list, + weight=weight + ) + + +def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]: + if num_parts is None: + # use batch size + num_parts = concatenated.text_embeds.shape[0] + text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0) + + if concatenated.pooled_embeds is not None: + pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0) + else: + pooled_embeds_splits = [None] * num_parts + + prompt_embeds_list = [ + PromptEmbeds([text, pooled]) + for text, pooled in zip(text_embeds_splits, pooled_embeds_splits) + ] + + return prompt_embeds_list + + +def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]: + target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds) + target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds) + positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds) + positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds) + negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds) + negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds) + neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds) + empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds) + both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds) + + prompt_pairs = [] + for i in range(len(target_class_splits)): + action_list_split = concatenated.action_list[i::len(target_class_splits)] + multiplier_list_split = concatenated.multiplier_list[i::len(target_class_splits)] + + prompt_pair = EncodedPromptPair( + target_class=target_class_splits[i], + target_class_with_neutral=target_class_with_neutral_splits[i], + positive_target=positive_target_splits[i], + positive_target_with_neutral=positive_target_with_neutral_splits[i], + negative_target=negative_target_splits[i], + negative_target_with_neutral=negative_target_with_neutral_splits[i], + neutral=neutral_splits[i], + empty_prompt=empty_prompt_splits[i], + both_targets=both_targets_splits[i], + action_list=action_list_split, + multiplier_list=multiplier_list_split, + weight=concatenated.weight + ) + prompt_pairs.append(prompt_pair) + + return prompt_pairs + + +class PromptEmbedsCache: + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0, + multiplier_list=None + ): + self.prompt = prompt + self.neg_prompt = neg_prompt + self.multiplier = multiplier + + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + + def to(self, *args, **kwargs): + self.prompt = self.prompt.to(*args, **kwargs) + self.neg_prompt = self.neg_prompt.to(*args, **kwargs) + return self + + +def concat_anchors(anchors: list[EncodedAnchor]): + prompt = concat_prompt_embeds([a.prompt for a in anchors]) + neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors]) + return EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier_list=[a.multiplier for a in anchors] + ) + + +def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]: + prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors) + neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors) + multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors) + + anchors = [] + for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits): + anchor = EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier=multiplier.tolist() + ) + anchors.append(anchor) + + return anchors + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +@torch.no_grad() +def encode_prompts_to_cache( + prompt_list: list[str], + sd: "StableDiffusion", + cache: Optional[PromptEmbedsCache] = None, + prompt_tensor_file: Optional[str] = None, +) -> PromptEmbedsCache: + # TODO: add support for larger prompts + if cache is None: + cache = PromptEmbedsCache() + + if prompt_tensor_file is not None: + # check to see if it exists + if os.path.exists(prompt_tensor_file): + # load it. + print(f"Loading prompt tensors from {prompt_tensor_file}") + prompt_tensors = load_file(prompt_tensor_file, device='cpu') + # add them to the cache + for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False): + if prompt_txt.startswith("te:"): + prompt = prompt_txt[3:] + # text_embeds + text_embeds = prompt_tensor + pooled_embeds = None + # find pool embeds + if f"pe:{prompt}" in prompt_tensors: + pooled_embeds = prompt_tensors[f"pe:{prompt}"] + + # make it + prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) + cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) + + if len(cache.prompts) == 0: + print("Prompt tensors not found. Encoding prompts..") + empty_prompt = "" + # encode empty_prompt + cache[empty_prompt] = sd.encode_prompt(empty_prompt) + + for p in tqdm(prompt_list, desc="Encoding prompts", leave=False): + # build the cache + if cache[p] is None: + cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16) + + # should we shard? It can get large + if prompt_tensor_file: + print(f"Saving prompt tensors to {prompt_tensor_file}") + state_dict = {} + for prompt_txt, prompt_embeds in cache.prompts.items(): + state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to( + "cpu", dtype=get_torch_dtype('fp16') + ) + if prompt_embeds.pooled_embeds is not None: + state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to( + "cpu", + dtype=get_torch_dtype('fp16') + ) + save_file(state_dict, prompt_tensor_file) + + return cache + + +if TYPE_CHECKING: + from toolkit.config_modules import SliderTargetConfig + + +@torch.no_grad() +def build_prompt_pair_batch_from_cache( + cache: PromptEmbedsCache, + target: 'SliderTargetConfig', + neutral: Optional[str] = '', +) -> list[EncodedPromptPair]: + erase_negative = len(target.positive.strip()) == 0 + enhance_positive = len(target.negative.strip()) == 0 + + both = not erase_negative and not enhance_positive + + prompt_pair_batch = [] + + if both or erase_negative: + print("Encoding erase negative") + prompt_pair_batch += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight + ), + ] + if both or enhance_positive: + print("Encoding enhance positive") + prompt_pair_batch += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight + ), + ] + if both or enhance_positive: + print("Encoding erase positive (inverse)") + prompt_pair_batch += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + if both or erase_negative: + print("Encoding enhance negative (inverse)") + prompt_pair_batch += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + both_targets=cache[f"{target.positive} {target.negative}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + + return prompt_pair_batch diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 2e924d19..f06f2929 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,6 +1,6 @@ import gc import typing -from typing import Union, OrderedDict, List +from typing import Union, OrderedDict, List, Tuple import sys import os @@ -50,10 +50,10 @@ def flush(): class PromptEmbeds: - text_embeds: torch.FloatTensor - pooled_embeds: Union[torch.FloatTensor, None] + text_embeds: torch.Tensor + pooled_embeds: Union[torch.Tensor, None] - def __init__(self, args) -> None: + def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None: if isinstance(args, list) or isinstance(args, tuple): # xl self.text_embeds = args[0] @@ -139,12 +139,23 @@ def load_model(self): pipln = self.custom_pipeline else: pipln = CustomStableDiffusionXLPipeline - pipe = pipln.from_single_file( - self.model_config.name_or_path, - dtype=dtype, - scheduler_type='ddpm', - device=self.device_torch, - ).to(self.device_torch) + + # see if path exists + if not os.path.exists(self.model_config.name_or_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='ddpm', + device=self.device_torch, + ).to(self.device_torch) + else: + pipe = pipln.from_single_file( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='ddpm', + device=self.device_torch, + ).to(self.device_torch) text_encoders = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] @@ -158,14 +169,27 @@ def load_model(self): pipln = self.custom_pipeline else: pipln = CustomStableDiffusionPipeline - pipe = pipln.from_single_file( - self.model_config.name_or_path, - dtype=dtype, - scheduler_type='dpm', - device=self.device_torch, - load_safety_checker=False, - requires_safety_checker=False, - ).to(self.device_torch) + + # see if path exists + if not os.path.exists(self.model_config.name_or_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='dpm', + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + ).to(self.device_torch) + else: + pipe = pipln.from_single_file( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='dpm', + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + ).to(self.device_torch) pipe.register_to_config(requires_safety_checker=False) text_encoder = pipe.text_encoder text_encoder.to(self.device_torch, dtype=dtype)