Skip to content

Commit

Permalink
Merge pull request #130 from cloneofsimo/develop
Browse files (browse the repository at this point in the history)
v0.1.1
  • Loading branch information
cloneofsimo authored Jan 9, 2023
2 parents da212b2 + 25aeab4 commit e19f6ae
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 28 deletions.
39 changes: 28 additions & 11 deletions lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def perform_tuning(
placeholder_token_ids,
placeholder_tokens,
save_path,
lr_scheduler_lora,
):

progress_bar = tqdm(range(num_steps))
Expand All @@ -430,6 +431,8 @@ def perform_tuning(

for epoch in range(math.ceil(num_steps / len(dataloader))):
for batch in dataloader:
lr_scheduler_lora.step()

optimizer.zero_grad()

loss = loss_step(
Expand All @@ -447,6 +450,11 @@ def perform_tuning(
)
optimizer.step()
progress_bar.update(1)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler_lora.get_last_lr()[0],
}
progress_bar.set_postfix(**logs)

global_step += 1

Expand Down Expand Up @@ -504,27 +512,29 @@ def train(
color_jitter: bool = True,
train_batch_size: int = 1,
sample_batch_size: int = 1,
max_train_steps_tuning: int = 10000,
max_train_steps_ti: int = 2000,
save_steps: int = 500,
gradient_accumulation_steps: int = 1,
max_train_steps_tuning: int = 1000,
max_train_steps_ti: int = 1000,
save_steps: int = 100,
gradient_accumulation_steps: int = 4,
gradient_checkpointing: bool = False,
mixed_precision="fp16",
lora_rank: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
clip_ti_decay: bool = True,
learning_rate_unet: float = 1e-5,
learning_rate_unet: float = 1e-4,
learning_rate_text: float = 1e-5,
learning_rate_ti: float = 5e-4,
continue_inversion: bool = True,
continue_inversion_lr: Optional[float] = None,
use_face_segmentation_condition: bool = False,
scale_lr: bool = False,
lr_scheduler: str = "constant",
lr_scheduler: str = "linear",
lr_warmup_steps: int = 0,
weight_decay_ti: float = 0.01,
weight_decay_lora: float = 0.01,
lr_scheduler_lora: str = "linear",
lr_warmup_steps_lora: int = 0,
weight_decay_ti: float = 0.00,
weight_decay_lora: float = 0.001,
use_8bit_adam: bool = False,
device="cuda:0",
extra_args: Optional[dict] = None,
Expand Down Expand Up @@ -553,7 +563,7 @@ def train(
placeholder_tokens = placeholder_tokens.split("|")
if initializer_tokens is None:
print("PTI : Initializer Token not give, random inits")
initializer_tokens = ["<rand-0.036>"] * len(placeholder_tokens)
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
else:
initializer_tokens = initializer_tokens.split("|")

Expand Down Expand Up @@ -588,8 +598,7 @@ def train(
)

if gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
unet.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()

if scale_lr:
unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
Expand Down Expand Up @@ -734,6 +743,13 @@ def train(

train_dataset.blur_amount = 70

lr_scheduler_lora = get_scheduler(
lr_scheduler_lora,
optimizer=lora_optimizers,
num_warmup_steps=lr_warmup_steps_lora,
num_training_steps=max_train_steps_tuning,
)

perform_tuning(
unet,
vae,
Expand All @@ -746,6 +762,7 @@ def train(
placeholder_tokens=placeholder_tokens,
placeholder_token_ids=placeholder_token_ids,
save_path=output_dir,
lr_scheduler_lora=lr_scheduler_lora,
)


Expand Down
26 changes: 13 additions & 13 deletions lora_diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,15 @@
]


def image_grid(_imgs, rows = None, cols = None):
def image_grid(_imgs, rows=None, cols=None):

if rows is None and cols is None:
rows = cols = math.ceil(len(_imgs) ** 0.5)

if rows is None:
rows = math.ceil(len(_imgs) / cols)
if cols is None:
cols = math.ceil(len(_imgs) / rows)


w, h = _imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
Expand Down Expand Up @@ -176,25 +175,23 @@ def visualize_progress(
text_sclae=1.0,
num_inference_steps=50,
guidance_scale=5.0,
offset : int = 0,
limit : int = 10,
seed : int = 0
offset: int = 0,
limit: int = 10,
seed: int = 0,
):


imgs = []
if isinstance(path_alls, str):
alls = list(set(glob.glob(path_alls)))

alls.sort(key=os.path.getmtime)
else:
alls = path_alls

pipe = StableDiffusionPipeline.from_pretrained(
model_id, torch_dtype=torch.float16
).to(device)


print(f"Found {len(alls)} checkpoints")
for path in alls[offset:limit]:
print(path)
Expand All @@ -207,8 +204,11 @@ def visualize_progress(
tune_lora_scale(pipe.text_encoder, text_sclae)

torch.manual_seed(seed)
image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
image = pipe(
prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
imgs.append(image)

return imgs

3 changes: 3 additions & 0 deletions training_scripts/multivector_example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ lora_pti \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--scale_lr \
--learning_rate_unet=1e-4 \
--learning_rate_text=1e-5 \
--learning_rate_ti=5e-4 \
--color_jitter \
--lr_scheduler="linear" \
--lr_warmup_steps=0 \
--lr_scheduler_lora="linear" \
--lr_warmup_steps_lora=100 \
--placeholder_tokens="<s1>|<s2>" \
--use_template="style"\
--save_steps=100 \
Expand Down
13 changes: 9 additions & 4 deletions training_scripts/train_lora_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,25 @@ def __init__(
self.class_prompt = class_prompt
else:
self.class_data_root = None

img_transforms = []

img_transforms = []

if resize:
img_transforms.append(transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR))
img_transforms.append(
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
)
)
if center_crop:
img_transforms.append(transforms.CenterCrop(size))
if color_jitter:
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
if h_flip:
img_transforms.append(transforms.RandomHorizontalFlip())

self.image_transforms = transforms.Compose([*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
self.image_transforms = transforms.Compose(
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

def __len__(self):
return self._length
Expand Down

0 comments on commit e19f6ae

Please sign in to comment.