Complete rework of how slider training works and optimized it to hell…
…. Can run the entire algorithm in 1 batch now, with VRAM consumption less than a quarter of what it used to take
jaretburkett committed Aug 6, 2023
1 parent 7e4e660 commit 8c90fa8
Showing 10 changed files with 942 additions and 377 deletions.
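The headline change is that the four slider passes now run as a single batched forward instead of four separate ones, sharing one gradient-accumulation step (see `batch_full_slide` in the example config below). A minimal PyTorch sketch of that batching pattern; the stand-in model and the four identical copies are illustrative assumptions, since the real trainer varies the prompt embedding and network multiplier per pass:

```python
import torch
import torch.nn as nn

# Stand-in for the diffusion UNet; only the batching mechanics matter here.
model = nn.Conv2d(4, 4, 3, padding=1)

def four_passes_separate(latent: torch.Tensor) -> list[torch.Tensor]:
    # Old approach (assumed): one forward call per slider pass.
    return [model(latent) for _ in range(4)]

def four_passes_batched(latent: torch.Tensor) -> list[torch.Tensor]:
    # New approach: stack the four passes on the batch dimension, run one
    # forward, then split the predictions back apart. All four parts now
    # share a single backward pass and gradient-accumulation step.
    stacked = torch.cat([latent] * 4, dim=0)
    preds = model(stacked)
    return list(preds.chunk(4, dim=0))

latent = torch.randn(1, 4, 64, 64)
assert all(p.shape == latent.shape for p in four_passes_batched(latent))
```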
15 changes: 12 additions & 3 deletions README.md
@@ -170,18 +170,27 @@ Just went in and out. It is much worse on smaller faces than shown here.

## Change Log

+#### 2023-08-05
+- Huge memory rework and slider rework. Slider training is better than ever with no more
+RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient
+accumulation. This makes it much faster and more stable.
+- Updated the example config to be something more practical and more in line with current methods. It is now
+a detail slider and shows how to train one without a subject. 512x512 slider training for 1.5 should work on
+a 6GB GPU now. Will test soon to verify.


#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
check out the example in the `extensions` folder. Read more about that above.
- Model Merging, provided via the example extension.

-#### 2021-08-03
+#### 2023-08-03
Another big refactor to make SD more modular.

Made batch image generation script

-#### 2021-08-01
+#### 2023-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are not breaking changes.
@@ -199,7 +208,7 @@ encoders to the model as well as a few more entirely separate diffusion networks
training without every experimental new paper added to it. The KISS principle.


-#### 2021-07-30
+#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights

104 changes: 65 additions & 39 deletions config/examples/train_slider.example.yml
@@ -7,7 +7,7 @@ job: train
config:
# the name will be used to create a folder in the output folder
# it will also replace any [name] token in the rest of this config
-name: pet_slider_v1
+name: detail_slider_v1
# folder will be created with name above in folder below
# it can be relative to the project root or absolute
training_folder: "output/LoRA"
@@ -24,7 +24,7 @@ config:
type: "lierla"
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
rank: 8
-alpha: 1.0 # just leave it
+alpha: 4 # Do about half of rank

# training config
train:
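A note on the `alpha` change above: in common LoRA implementations the adapter's output is scaled by `alpha / rank`, so with `rank: 8` the new `alpha: 4` gives a scale of 0.5 where the old `alpha: 1.0` gave 0.125. A minimal sketch of that scaling, assuming a standard LoRA linear layer rather than this repo's actual network code:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = Wx + (alpha / rank) * B(A(x))."""

    def __init__(self, in_f: int, out_f: int, rank: int = 8, alpha: float = 4.0):
        super().__init__()
        self.base = nn.Linear(in_f, out_f)
        self.down = nn.Linear(in_f, rank, bias=False)  # A: project down to rank
        self.up = nn.Linear(rank, out_f, bias=False)   # B: project back up
        nn.init.zeros_(self.up.weight)                 # adapter starts as a no-op
        self.scale = alpha / rank                      # 4 / 8 = 0.5 here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.up(self.down(x))
```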
@@ -33,7 +33,7 @@
# how many steps to train. More is not always better. I rarely go over 1000
steps: 500
# I have had good results with 4e-4 to 1e-4 at 500 steps
-lr: 1e-4
+lr: 2e-4
# enables gradient checkpoint, saves vram, leave it on
gradient_checkpointing: true
# train the unet. I recommend leaving this true
@@ -43,6 +43,7 @@ config:
# not the description of it (text encoder)
train_text_encoder: false


# just leave unless you know what you are doing
# also supports "dadaptation" but set lr to 1 if you use that,
# but it learns too fast and I don't recommend it
@@ -53,6 +54,7 @@
# while training. Just leave it
max_denoising_steps: 40
# works great at 1. I do 1 even with my 4090.
+# higher may not work right with the newer single-batch stacking code anyway
batch_size: 1
# bf16 works best if your GPU supports it (modern)
dtype: bf16 # fp32, bf16, fp16
@@ -69,12 +71,17 @@ config:
name_or_path: "runwayml/stable-diffusion-v1-5"
is_v2: false # for v2 models
is_v_pred: false # for v-prediction models (most v2 models)
+# has some issues with the dual text encoder and the way we train sliders
+# it works, but weights probably need to be higher to see the effect
is_xl: false # for SDXL models

# saving config
save:
dtype: float16 # precision to save. I recommend float16
save_every: 50 # save every this many steps
+# this will remove step saves beyond this number
+# allows you to save more often in case of a crash without filling up your drive
+max_step_saves_to_keep: 2

# sampling config
sample:
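The new `max_step_saves_to_keep` option above implies a save-rotation scheme: after each checkpoint, prune the oldest step saves beyond the limit. A hedged sketch of such pruning; the file naming, extension, and flat layout here are assumptions, not necessarily this repo's actual scheme:

```python
from pathlib import Path

def prune_step_saves(save_dir: str, keep: int = 2) -> None:
    # Assumes saves are named with zero-padded step counts, e.g.
    # "detail_slider_v1_000000050.safetensors", so lexicographic
    # order matches step order.
    saves = sorted(Path(save_dir).glob("*.safetensors"))
    for stale in saves[:-keep]:  # everything but the newest `keep` files
        stale.unlink()
```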
@@ -92,21 +99,22 @@
# --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
# slide are good tests. will inherit sample.network_multiplier if not set
# --n [string] # negative prompt, will inherit sample.neg if not set

+# Only 75 tokens allowed currently
-prompts: # our example is an animal slider, neg: dog, pos: cat
-- "a golden retriever --m -5"
-- "a golden retriever --m -3"
-- "a golden retriever --m 3"
-- "a golden retriever --m 5"
-- "calico cat --m -5"
-- "calico cat --m -3"
-- "calico cat --m 3"
-- "calico cat --m 5"
-- "an elephant --m -5"
-- "an elephant --m -3"
-- "an elephant --m 3"
-- "an elephant --m 5"
+# I like to do a wide positive and negative spread so I can see a good range and stop
+# early if the network is breaking down
+prompts:
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
+- "a golden retriever sitting on a leather couch --m -5"
+- "a golden retriever sitting on a leather couch --m -3"
+- "a golden retriever sitting on a leather couch --m 3"
+- "a golden retriever sitting on a leather couch --m 5"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
# negative prompt used on all prompts above as default if they don't have one
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
# seed for sampling. 42 is the answer for everything
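The sample prompts above embed per-prompt overrides (`--m`, `--n`) directly in the prompt string. A minimal sketch of how such flags can be split out of a prompt; this illustrates the documented convention, not the repo's actual parser:

```python
def parse_sample_prompt(raw: str, default_mult: float = 1.0, default_neg: str = ""):
    # "a golden retriever --m -5" -> ("a golden retriever", -5.0, default_neg)
    parts = raw.split("--")
    prompt, mult, neg = parts[0].strip(), default_mult, default_neg
    for flag in (p.strip() for p in parts[1:]):
        if flag.startswith("m "):
            mult = float(flag[2:])
        elif flag.startswith("n "):
            neg = flag[2:]
    return prompt, mult, neg

assert parse_sample_prompt("a golden retriever --m -5") == ("a golden retriever", -5.0, "")
```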
@@ -135,11 +143,16 @@ config:
# resolutions to train on. [ width, height ]. This is less important for sliders
# as we are not teaching the model anything it doesn't already know
# but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
+# and [ 1024, 1024 ] for sd_xl
# you can do as many as you want here
resolutions:
- [ 512, 512 ]
# - [ 512, 768 ]
# - [ 768, 768 ]
+# slider training uses 4 combined steps for a single round. This will do it in one gradient
+# step. It is highly optimized and shouldn't take any more VRAM than doing it without,
+# since we break down batches for gradient accumulation now, so just leave it on
+batch_full_slide: true
# These are the concepts to train on. You can do as many as you want here,
# but they can conflict and outweigh each other. Other than experimenting, I recommend
# just doing one for good results
@@ -150,41 +163,54 @@
# a keyword necessarily but what the model understands the concept to represent.
# "person" will affect men, women, children, etc but will not affect cats, dogs, etc
# it is the models base general understanding of the concept and everything it represents
-- target_class: "animal"
+# you can leave it blank to affect everything. In this example, we are adjusting
+# detail, so we will leave it blank
+- target_class: ""
# positive is the prompt for the positive side of the slider.
# It is the concept that will be excited and amplified in the model when we slide the slider
# to the positive side and forgotten / inverted when we slide
# the slider to the negative side. It is generally best to include the target_class in
# the prompt. You want it to be the extreme of what you want to train on. For example,
# if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
# as the prompt. Not just "fat person"
-positive: "cat"
+# max 75 tokens for now
+positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
# negative is the prompt for the negative side of the slider and works the same as positive
# it does not necessarily work the same as a negative prompt when generating images
-negative: "dog"
+# these need to be polar opposites.
+# max 75 tokens for now
+negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
# the loss for this target is multiplied by this number.
# if you are doing more than one target it may be good to set less important ones
-# to a lower number like 0.1 so they dont outweigh the primary target
+# to a lower number like 0.1 so they don't outweigh the primary target
weight: 1.0

-# anchors are prompts that wer try to hold on to while training the slider
-# you want these to generate an image very similar to the target_class
-# without directly overlapping it. For example, if you are training on a person smiling,
-# you would use "a person with a face mask" as an anchor. It is a person, the image is the same
-# regardless if they are smiling or not
-anchors:
-# only positive prompt for now
-- prompt: "a woman"
-neg_prompt: "animal"
-# the multiplier applied to the LoRA when this is run.
-# higher will give it more weight but also help keep the lora from collapsing
-multiplier: 8.0
-- prompt: "a man"
-neg_prompt: "animal"
-multiplier: 8.0
-- prompt: "a person"
-neg_prompt: "animal"
-multiplier: 8.0

+# anchors are prompts that we will try to hold on to while training the slider
+# these are NOT necessary and can prevent the slider from converging if not done right
+# leave them off if you are having issues, but they can help lock the network
+# on certain concepts to help prevent catastrophic forgetting
+# you want these to generate an image that is not your target_class, but close to it
+# is fine as long as it does not directly overlap it.
+# For example, if you are training on a person smiling,
+# you could use "a person with a face mask" as an anchor. It is a person, the image is the same
+# regardless of whether they are smiling or not. However, the closer the concept is to the target_class,
+# the lower the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
+# for close concepts, you want to be closer to 0.1 or 0.2
+# these will slow down training. I am leaving them off for the demo

+# anchors:
+# - prompt: "a woman"
+# neg_prompt: "animal"
+# # the multiplier applied to the LoRA when this is run.
+# # higher will give it more weight but also help keep the lora from collapsing
+# multiplier: 1.0
+# - prompt: "a man"
+# neg_prompt: "animal"
+# multiplier: 1.0
+# - prompt: "a person"
+# neg_prompt: "animal"
+# multiplier: 1.0

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
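The anchor comments above describe a regularizer: with the LoRA active, predictions for an anchor prompt should stay close to what the unmodified base model would produce. A rough sketch of that idea with toy stand-ins; the loss form, the multiplier toggle, and all names here are assumptions rather than the repo's exact code:

```python
import torch
import torch.nn.functional as F

class ToyLoRAModel(torch.nn.Module):
    # Stand-in for UNet + LoRA network with a weight multiplier.
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(8, 8)
        self.lora = torch.nn.Linear(8, 8, bias=False)
        self.multiplier = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.multiplier * self.lora(x)

def anchor_loss(model: ToyLoRAModel, x: torch.Tensor, multiplier: float) -> torch.Tensor:
    with torch.no_grad():            # target: LoRA switched off
        model.multiplier = 0.0
        target = model(x)
    model.multiplier = multiplier    # prediction: LoRA active at the anchor multiplier
    return F.mse_loss(model(x), target)

loss = anchor_loss(ToyLoRAModel(), torch.randn(2, 8), multiplier=1.0)
loss.backward()
```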
2 changes: 1 addition & 1 deletion info.py
@@ -3,6 +3,6 @@
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.3"
v["version"] = "0.0.4"

software_meta = v
9 changes: 9 additions & 0 deletions jobs/process/BaseSDTrainProcess.py
@@ -242,6 +242,12 @@ def run(self):
unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
+# if isinstance(text_encoder, list):
+#     for te in text_encoder:
+#         te.enable_gradient_checkpointing()
+# else:
+#     text_encoder.enable_gradient_checkpointing()

unet.to(self.device_torch, dtype=dtype)
unet.requires_grad_(False)
unet.eval()
@@ -281,6 +287,9 @@
default_lr=self.train_config.lr
)

+if self.train_config.gradient_checkpointing:
+    self.network.enable_gradient_checkpointing()

latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
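The last hunk enables gradient checkpointing on the LoRA network as well as the UNet. In PyTorch, checkpointing trades compute for memory by recomputing a block's activations during the backward pass instead of storing them. A minimal sketch with `torch.utils.checkpoint`, using a toy block rather than the repo's network class:

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(4, 64, requires_grad=True)

# Activations inside `block` are freed after the forward pass and
# recomputed during backward, cutting peak memory at the cost of compute.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```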