diff --git a/README.md b/README.md index 4fabac63..55f02c12 100644 --- a/README.md +++ b/README.md @@ -104,8 +104,18 @@ Just went in and out. It is much worse on smaller faces than shown here. +--- + ## TODO -- [ ] Add proper regs on sliders +- [X] Add proper regs on sliders - [ ] Add SDXL support (base model only for now) - [ ] Add plain erasing -- [ ] Make Textual inversion network trainer (network that spits out TI embeddings) \ No newline at end of file +- [ ] Make Textual inversion network trainer (network that spits out TI embeddings) + +--- + +## Change Log +#### 2021-07-30 +Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a +regularizer. You can set the network multiplier to force spread consistency at high weights. + diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml index 34c12e5f..c339e5aa 100644 --- a/config/examples/train_slider.example.yml +++ b/config/examples/train_slider.example.yml @@ -163,6 +163,25 @@ config: # to a lower number like 0.1 so they dont outweigh the primary target weight: 1.0 + # anchors are prompts that we try to hold on to while training the slider + # you want these to generate an image very similar to the target_class + # without directly overlapping it. For example, if you are training on a person smiling, + # you would use "a person with a face mask" as an anchor. It is a person, the image is the same + # regardless of whether they are smiling or not + anchors: + # only positive prompt for now + - prompt: "a woman" + neg_prompt: "animal" + # the multiplier applied to the LoRA when this is run. + # higher will give it more weight but also help keep the lora from collapsing + multiplier: 8.0 + - prompt: "a man" + neg_prompt: "animal" + multiplier: 8.0 + - prompt: "a person" + neg_prompt: "animal" + multiplier: 8.0 + # You can put any information you want here, and it will be saved in the model. 
# The below is an example, but you can put your grocery list in it if you want. # It is saved in the model so be aware of that. The software will include this diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index a337bf3f..79774f8f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -2,6 +2,7 @@ from collections import OrderedDict import os +from leco.train_util import predict_noise from toolkit.kohya_model_util import load_vae from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -59,11 +60,10 @@ def __init__(self, process_id: int, job, config: OrderedDict): self.logging_config = LogingConfig(**self.get_conf('logging', {})) self.optimizer = None self.lr_scheduler = None - self.sd = None + self.sd: 'StableDiffusion' = None # added later self.network = None - self.scheduler = None def sample(self, step=None): sample_folder = os.path.join(self.save_root, 'samples') @@ -118,7 +118,7 @@ def sample(self, step=None): 'multiplier': self.network.multiplier, }) - for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"): + for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False): raw_prompt = self.sample_config.prompts[i] neg = self.sample_config.neg @@ -267,6 +267,27 @@ def hook_train_loop(self): # return loss return 0.0 + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: torch.FloatTensor, + total_timesteps: int = 1000, + start_timesteps=0, + **kwargs, + ): + + for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False): + noise_pred = predict_noise( + self.sd.unet, self.sd.noise_scheduler, timestep, latents, 
text_embeddings, **kwargs + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents + def run(self): super().run() @@ -368,7 +389,21 @@ def run(self): # todo handle dataloader here maybe, not sure ### HOOK ### - loss = self.hook_train_loop() + loss_dict = self.hook_train_loop() + + if self.train_config.optimizer.startswith('dadaptation'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + prog_bar_string = f"lr: {learning_rate:.1e}" + for key, value in loss_dict.items(): + prog_bar_string += f" {key}: {value:.3e}" + + self.progress_bar.set_postfix_str(prog_bar_string) # don't do on first step if self.step_num != self.start_step: @@ -386,15 +421,8 @@ def run(self): if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: # log to tensorboard if self.writer is not None: - # get avg loss - self.writer.add_scalar(f"loss", loss, self.step_num) - if self.train_config.optimizer.startswith('dadaptation'): - learning_rate = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) - else: - learning_rate = optimizer.param_groups[0]['lr'] + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) self.writer.add_scalar(f"lr", learning_rate, self.step_num) self.progress_bar.refresh() diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 3aa879bd..70ddbcd6 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -29,6 +29,7 @@ def flush(): gc.collect() + class EncodedPromptPair: def __init__( self, @@ -53,6 +54,18 @@ def __init__( self.weight = weight +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0 + ): + self.prompt = prompt + self.neg_prompt = neg_prompt 
+ self.multiplier = multiplier + + class TrainSliderProcess(BaseSDTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) @@ -61,9 +74,9 @@ def __init__(self, process_id: int, job, config: OrderedDict): self.device = self.get_conf('device', self.job.device) self.device_torch = torch.device(self.device) self.slider_config = SliderConfig(**self.get_conf('slider', {})) - self.prompt_cache = PromptEmbedsCache() self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] def before_model_load(self): pass @@ -146,16 +159,39 @@ def hook_before_train_loop(self): ), ] + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # negative prompt for the anchor + ]: + if cache[prompt] is None: + cache[prompt] = train_util.encode_prompts( + self.sd.tokenizer, self.sd.text_encoder, [prompt] + ) + + anchor_pairs += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier + ) + ] + # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling self.sd.text_encoder.to("cpu") self.prompt_cache = cache self.prompt_pairs = prompt_pairs + self.anchor_pairs = anchor_pairs flush() # end hook_before_train_loop def hook_train_loop(self): dtype = get_torch_dtype(self.train_config.dtype) + # get a random pair prompt_pair: EncodedPromptPair = self.prompt_pairs[ torch.randint(0, len(self.prompt_pairs), (1,)).item() @@ -202,10 +238,7 @@ def hook_train_loop(self): with self.network: assert self.network.is_active - # A little denoised one is returned - denoised_latents = train_util.diffusion( - unet, - noise_scheduler, + denoised_latents = self.diffuse_some_steps( latents, # pass simple noise latents train_util.concat_embeddings( positive, # unconditional @@ -261,7 +294,46 @@ def hook_train_loop(self): guidance_scale=1, 
).to("cpu", dtype=torch.float32) + anchor_loss = None + if len(self.anchor_pairs) > 0: + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + with torch.no_grad(): + anchor_target_noise = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + anchor.prompt, + anchor.neg_prompt, + self.train_config.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + with self.network: + # anchor whatever weight prompt pair is using + pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 + self.network.multiplier = anchor.multiplier * pos_nem_mult + anchor_pred_noise = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + anchor.prompt, + anchor.neg_prompt, + self.train_config.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + + self.network.multiplier = prompt_pair.multiplier + with self.network: + self.network.multiplier = prompt_pair.multiplier target_latents = train_util.predict_noise( unet, noise_scheduler, @@ -281,7 +353,12 @@ def hook_train_loop(self): positive_latents.requires_grad = False neutral_latents.requires_grad = False unconditional_latents.requires_grad = False - + if len(self.anchor_pairs) > 0: + anchor_target_noise.requires_grad = False + anchor_loss = loss_function( + anchor_target_noise, + anchor_pred_noise, + ) erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE guidance_scale = 1.0 @@ -299,16 +376,14 @@ def hook_train_loop(self): offset_neutral, ) * weight + loss_slide = loss.item() + + if anchor_loss is not None: + loss += anchor_loss + loss_float = loss.item() - if self.train_config.optimizer.startswith('dadaptation'): - learning_rate = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) - else: - learning_rate = optimizer.param_groups[0]['lr'] - 
self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}") + loss = loss.to(self.device_torch) loss.backward() optimizer.step() @@ -326,5 +401,12 @@ def hook_train_loop(self): # reset network self.network.multiplier = 1.0 - return loss_float + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + if anchor_loss is not None: + loss_dict['sl_l'] = loss_slide + loss_dict['an_l'] = anchor_loss.item() + + return loss_dict # end hook_train_loop diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 436abb1b..4a22f331 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -71,9 +71,19 @@ def __init__(self, **kwargs): self.weight: float = kwargs.get('weight', 1.0) +class SliderConfigAnchors: + def __init__(self, **kwargs): + self.prompt = kwargs.get('prompt', '') + self.neg_prompt = kwargs.get('neg_prompt', '') + self.multiplier = kwargs.get('multiplier', 1.0) + + class SliderConfig: def __init__(self, **kwargs): targets = kwargs.get('targets', []) targets = [SliderTargetConfig(**target) for target in targets] self.targets: List[SliderTargetConfig] = targets + anchors = kwargs.get('anchors', []) + anchors = [SliderConfigAnchors(**anchor) for anchor in anchors] + self.anchors: List[SliderConfigAnchors] = anchors self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])