Added support for training on primary gpu with low_vram flag. Updated example script to remove creepy horse sample at that seed
jaretburkett committed Aug 11, 2024
1 parent fa02e77 commit ec1ea7a
Showing 4 changed files with 30 additions and 13 deletions.
21 changes: 15 additions & 6 deletions README.md
@@ -5,6 +5,10 @@ This is my research repo. I do a lot of experiments in it and it is possible tha
If something breaks, check out an earlier commit. This repo can train a lot of things, and it is
hard to keep up with all of them.

## Support my work

My work would not be possible without the amazing support of [Glif](https://glif.app/).

## Installation

Requirements:
@@ -43,16 +47,21 @@ pip install -r requirements.txt

### WIP. I am updating docs and optimizing as fast as I can. If there are bugs, open a ticket. Not knowing how to get it to work is NOT a bug. Be patient as I continue to develop it.

Training currently only works with FLUX.1-dev. Which means anything you train will inherit the
non commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to setup a license.

### Requirements
You currently need a dedicated GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
your monitors, it will probably not fit as that takes up some ram. I may be able to get this lower, but for now,
It won't work. It may not work on Windows, I have only tested on linux for now. This is still extremely experimental
You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
but there are some reports of a bug when running on Windows natively.
I have only tested on Linux for now. This is still extremely experimental
and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.

### Model License

Training currently only works with FLUX.1-dev. Which means anything you train will inherit the
non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to set up access.

1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
2. Make a file named `.env` in the root of this folder
3. [Get a READ key from Hugging Face](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so: `HF_TOKEN=your_key_here`
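For reference, here is a minimal sketch of how the token in `.env` could be picked up at runtime. It assumes the `python-dotenv` and `huggingface_hub` packages are installed; the toolkit itself may wire this up differently, so treat the helper below as illustrative only.

```python
# check_hf_token.py - hypothetical helper, not part of this repo
import os

from dotenv import load_dotenv      # pip install python-dotenv
from huggingface_hub import login   # pip install huggingface_hub

# Pull variables from the .env file in the repo root into the environment
load_dotenv()

token = os.getenv("HF_TOKEN")
if token is None:
    raise RuntimeError("HF_TOKEN not found; add it to your .env file")

# Authenticate so gated models like black-forest-labs/FLUX.1-dev can be downloaded
login(token=token)
```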
11 changes: 7 additions & 4 deletions config/examples/train_lora_flux_24gb.yaml
@@ -25,16 +25,16 @@ config:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
- folder_path: "/mnt/Datasets/1920s_illustrations"
# - folder_path: "/path/to/images/folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
num_workers: 0
train:
batch_size: 1
steps: 4000 # total number of steps to train
steps: 4000 # total number of steps to train 500 - 4000 is a good range
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false # probably won't work with flux
@@ -43,6 +43,8 @@ config:
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw8bit"
lr: 4e-4
# uncomment this to skip the pre training sample
# skip_first_sample: true

# ema will smooth out learning, but could slow it down. Recommended to leave on.
ema_config:
@@ -56,6 +58,7 @@ config:
name_or_path: "black-forest-labs/FLUX.1-dev"
is_flux: true
quantize: true # run 8bit mixed precision
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
@@ -66,7 +69,7 @@ config:
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse in a night club dancing, fish eye lens, smoke machine, lazer lights, holding a martini, large group"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
1 change: 1 addition & 0 deletions toolkit/config_modules.py
@@ -411,6 +411,7 @@ def __init__(self, **kwargs):

# only for flux for now
self.quantize = kwargs.get("quantize", False)
self.low_vram = kwargs.get("low_vram", False)
pass


10 changes: 7 additions & 3 deletions toolkit/stable_diffusion_model.py
@@ -56,7 +56,7 @@
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance

from optimum.quanto import freeze, qfloat8, quantize, QTensor
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4

# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -174,6 +174,7 @@ def __init__(
self.is_flow_matching = True

self.quantize_device = quantize_device if quantize_device is not None else self.device
self.low_vram = self.model_config.low_vram

def load_model(self):
if self.is_loaded:
@@ -472,7 +473,9 @@ def load_model(self):
# low_cpu_mem_usage=False,
# device_map=None
)
transformer.to(torch.device(self.quantize_device), dtype=dtype)
if not self.low_vram:
# for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()

if self.model_config.lora_path is not None:
@@ -493,8 +496,9 @@ def load_model(self):
pipe.unload_lora_weights()

if self.model_config.quantize:
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=qfloat8)
quantize(transformer, weights=quantization_type)
freeze(transformer)
transformer.to(self.device_torch)
else:
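Taken together, the low_vram path amounts to quantizing the transformer while it still sits on the CPU and only moving the already-quantized weights to the GPU. Below is a standalone sketch of that pattern with `optimum.quanto`, using a small placeholder module in place of the FLUX.1-dev transformer; the `low_vram` name mirrors the config flag and everything else is illustrative, not the toolkit's actual loader.

```python
# low_vram_quantize_sketch.py - illustrative only
import torch
from torch import nn
from optimum.quanto import freeze, qfloat8, quantize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the FLUX transformer loaded by the toolkit
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

low_vram = True  # mirrors model.low_vram in the training config

if not low_vram:
    # Default path: move the full-precision weights to the GPU first, then quantize there (faster)
    model.to(device, dtype=torch.bfloat16)

# With low_vram the weights stay on the CPU during quantization: slower,
# but the GPU that is driving your monitors keeps its memory free
quantize(model, weights=qfloat8)
freeze(model)

# Only the quantized weights are moved onto the GPU for training
model.to(device)
```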
