Skip to content

Commit

Permalink
Added anchors to regulate the lora
Browse files Browse the repository at this point in the history
  • Loading branch information
jaretburkett committed Jul 24, 2023
1 parent 390192c commit 61dd818
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 30 deletions.
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,18 @@ Just went in and out. It is much worse on smaller faces than shown here.

<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">

---

## TODO
- [ ] Add proper regs on sliders
- [X] Add proper regs on sliders
- [ ] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)

---

## Change Log
#### 2021-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights.

19 changes: 19 additions & 0 deletions config/examples/train_slider.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ config:
# to a lower number like 0.1 so they don't outweigh the primary target
weight: 1.0

# anchors are prompts that we try to hold on to while training the slider
# you want these to generate an image very similar to the target_class
# without directly overlapping it. For example, if you are training on a person smiling,
# you could use "a person with a face mask" as an anchor. It is still a person, and the image
# is the same regardless of whether they are smiling or not
anchors:
# only positive prompt for now
- prompt: "a woman"
neg_prompt: "animal"
# the multiplier applied to the LoRA when this is run.
# higher will give it more weight but also help keep the lora from collapsing
multiplier: 8.0
- prompt: "a man"
neg_prompt: "animal"
multiplier: 8.0
- prompt: "a person"
neg_prompt: "animal"
multiplier: 8.0

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
Expand Down
54 changes: 41 additions & 13 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
import os

from leco.train_util import predict_noise
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
Expand Down Expand Up @@ -59,11 +60,10 @@ def __init__(self, process_id: int, job, config: OrderedDict):
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.lr_scheduler = None
self.sd = None
self.sd: 'StableDiffusion' = None

# added later
self.network = None
self.scheduler = None

def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
Expand Down Expand Up @@ -118,7 +118,7 @@ def sample(self, step=None):
'multiplier': self.network.multiplier,
})

for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
raw_prompt = self.sample_config.prompts[i]

neg = self.sample_config.neg
Expand Down Expand Up @@ -267,6 +267,27 @@ def hook_train_loop(self):
# return loss
return 0.0

# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
        self,
        latents: torch.FloatTensor,
        text_embeddings: torch.FloatTensor,
        total_timesteps: int = 1000,
        start_timesteps=0,
        **kwargs,
):
    """Step the latents through a slice of the noise scheduler's timesteps.

    The slice ``timesteps[start_timesteps:total_timesteps]`` of
    ``self.sd.noise_scheduler`` is walked in order. For every timestep the
    noise is predicted with ``train_util.predict_noise`` and handed to the
    scheduler's ``step``; the resulting ``prev_sample`` becomes the latents
    for the next iteration.

    :param latents: latents to start from
    :param text_embeddings: embeddings passed through to ``predict_noise``
    :param total_timesteps: end index (exclusive) into the scheduler timesteps
    :param start_timesteps: start index into the scheduler timesteps
    :param kwargs: forwarded unchanged to ``train_util.predict_noise``
    :return: the latents after the last step (the input latents when the
        slice is empty)
    """
    scheduler = self.sd.noise_scheduler
    schedule = scheduler.timesteps[start_timesteps:total_timesteps]

    for current_step in tqdm(schedule, leave=False):
        predicted_noise = train_util.predict_noise(
            self.sd.unet, scheduler, current_step, latents, text_embeddings, **kwargs
        )
        # x_t -> x_t-1: let the scheduler produce the previous noisy sample
        step_output = scheduler.step(predicted_noise, current_step, latents)
        latents = step_output.prev_sample

    return latents

def run(self):
super().run()

Expand Down Expand Up @@ -368,7 +389,21 @@ def run(self):
# todo handle dataloader here maybe, not sure

### HOOK ###
loss = self.hook_train_loop()
loss_dict = self.hook_train_loop()

if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']

prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"

self.progress_bar.set_postfix_str(prog_bar_string)

# don't do on first step
if self.step_num != self.start_step:
Expand All @@ -386,15 +421,8 @@ def run(self):
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
# get avg loss
self.writer.add_scalar(f"loss", loss, self.step_num)
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()

Expand Down
112 changes: 97 additions & 15 deletions jobs/process/TrainSliderProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def flush():
gc.collect()



class EncodedPromptPair:
def __init__(
self,
Expand All @@ -53,6 +54,18 @@ def __init__(
self.weight = weight


class EncodedAnchor:
    """Plain holder for one anchor: a prompt, a negative prompt and a multiplier.

    The values are stored exactly as given; no validation or conversion is done.
    """

    def __init__(self, prompt, neg_prompt, multiplier=1.0):
        # multiplier defaults to 1.0 when the caller does not supply one
        self.prompt, self.neg_prompt, self.multiplier = prompt, neg_prompt, multiplier


class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
Expand All @@ -61,9 +74,9 @@ def __init__(self, process_id: int, job, config: OrderedDict):
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.slider_config = SliderConfig(**self.get_conf('slider', {}))

self.prompt_cache = PromptEmbedsCache()
self.prompt_pairs: list[EncodedPromptPair] = []
self.anchor_pairs: list[EncodedAnchor] = []

def before_model_load(self):
pass
Expand Down Expand Up @@ -146,16 +159,39 @@ def hook_before_train_loop(self):
),
]

# setup anchors
anchor_pairs = []
for anchor in self.slider_config.anchors:
# build the cache
for prompt in [
anchor.prompt,
anchor.neg_prompt # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)

anchor_pairs += [
EncodedAnchor(
prompt=cache[anchor.prompt],
neg_prompt=cache[anchor.neg_prompt],
multiplier=anchor.multiplier
)
]

# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.prompt_pairs = prompt_pairs
self.anchor_pairs = anchor_pairs
flush()
# end hook_before_train_loop

def hook_train_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)

# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
Expand Down Expand Up @@ -202,10 +238,7 @@ def hook_train_loop(self):

with self.network:
assert self.network.is_active
# A little denoised one is returned
denoised_latents = train_util.diffusion(
unet,
noise_scheduler,
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_util.concat_embeddings(
positive, # unconditional
Expand Down Expand Up @@ -261,7 +294,46 @@ def hook_train_loop(self):
guidance_scale=1,
).to("cpu", dtype=torch.float32)

anchor_loss = None
if len(self.anchor_pairs) > 0:
# get a random anchor pair
anchor: EncodedAnchor = self.anchor_pairs[
torch.randint(0, len(self.anchor_pairs), (1,)).item()
]
with torch.no_grad():
anchor_target_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
with self.network:
# anchor whatever weight prompt pair is using
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
self.network.multiplier = anchor.multiplier * pos_nem_mult
anchor_pred_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)

self.network.multiplier = prompt_pair.multiplier

with self.network:
self.network.multiplier = prompt_pair.multiplier
target_latents = train_util.predict_noise(
unet,
noise_scheduler,
Expand All @@ -281,7 +353,12 @@ def hook_train_loop(self):
positive_latents.requires_grad = False
neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False

if len(self.anchor_pairs) > 0:
anchor_target_noise.requires_grad = False
anchor_loss = loss_function(
anchor_target_noise,
anchor_pred_noise,
)
erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
guidance_scale = 1.0

Expand All @@ -299,16 +376,14 @@ def hook_train_loop(self):
offset_neutral,
) * weight

loss_slide = loss.item()

if anchor_loss is not None:
loss += anchor_loss

loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']

self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")
loss = loss.to(self.device_torch)

loss.backward()
optimizer.step()
Expand All @@ -326,5 +401,12 @@ def hook_train_loop(self):
# reset network
self.network.multiplier = 1.0

return loss_float
loss_dict = OrderedDict(
{'loss': loss_float},
)
if anchor_loss is not None:
loss_dict['sl_l'] = loss_slide
loss_dict['an_l'] = anchor_loss.item()

return loss_dict
# end hook_train_loop
10 changes: 10 additions & 0 deletions toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,19 @@ def __init__(self, **kwargs):
self.weight: float = kwargs.get('weight', 1.0)


class SliderConfigAnchors:
    """Config entry for a single slider anchor, populated from keyword arguments.

    Recognised keys are ``prompt``, ``neg_prompt`` and ``multiplier``; any other
    keyword is ignored.
    """

    def __init__(self, **kwargs):
        # (attribute name, value used when the key is missing)
        fallbacks = (
            ('prompt', ''),
            ('neg_prompt', ''),
            ('multiplier', 1.0),
        )
        for attr_name, fallback in fallbacks:
            setattr(self, attr_name, kwargs.get(attr_name, fallback))


class SliderConfig:
    """Slider training configuration built from keyword arguments.

    Recognised keys:
      - ``targets``: list of dicts, each expanded into a ``SliderTargetConfig``
      - ``anchors``: list of dicts, each expanded into a ``SliderConfigAnchors``
      - ``resolutions``: list of ``[a, b]`` integer pairs, default ``[[512, 512]]``

    A key that is present but holds ``None`` (for example a YAML ``anchors:``
    line whose entries are all commented out) is treated the same as a missing
    key instead of raising ``TypeError`` while iterating.
    """

    def __init__(self, **kwargs):
        # `or []` maps both a missing key and an explicit null to "no entries"
        raw_targets = kwargs.get('targets') or []
        self.targets: List[SliderTargetConfig] = [
            SliderTargetConfig(**target) for target in raw_targets
        ]

        raw_anchors = kwargs.get('anchors') or []
        self.anchors: List[SliderConfigAnchors] = [
            SliderConfigAnchors(**anchor) for anchor in raw_anchors
        ]

        # only None falls back to the default; an explicit list is kept as given
        resolutions = kwargs.get('resolutions')
        self.resolutions: List[List[int]] = [[512, 512]] if resolutions is None else resolutions

0 comments on commit 61dd818

Please sign in to comment.