Commit 99ba84b

Merge pull request #171 from cloneofsimo/develop
v0.1.6
cloneofsimo authored Feb 9, 2023
2 parents 848db91 + a9354e6 commit 99ba84b
Showing 16 changed files with 801 additions and 56 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -37,6 +37,7 @@
- Fine-tune Stable Diffusion models twice as fast as the Dreambooth method, via Low-rank Adaptation
- Get an insanely small end result (1 MB ~ 6 MB), easy to share and download
- Compatible with `diffusers`
- Support for inpainting
- Sometimes _even better performance_ than full fine-tuning (extensive comparisons are left as future work)
- Merge checkpoints + build recipes by merging LoRAs together
- Pipeline to fine-tune CLIP + UNet + token embeddings to gain better results
@@ -50,6 +51,10 @@

# UPDATES & Notes

### 2023/02/06

- Support for inpainting training with LoRA PTI. Use the `--train-inpainting` flag with an inpainting Stable Diffusion base model (see `inpainting_example.sh`); a sketch follows below.
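
  A minimal sketch of such a run (assumptions: `train` is the function behind the `lora_pti` entry point, and the model id, paths, and token choices are illustrative placeholders, not values from this PR):

  ```python
  from lora_diffusion.cli_lora_pti import train

  # Sketch only: every path and hyperparameter here is a placeholder.
  train(
      pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting",  # an inpainting base model
      instance_data_dir="./my_images",
      output_dir="./output",
      train_inpainting=True,  # enables the inpainting dataloader and 9-channel UNet input
      use_template="object",
      placeholder_tokens="<s1>",
      initializer_tokens="person",
  )
  ```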

### 2023/02/01

- LoRA joining is now available with the `--mode=ljl` flag. Only three parameters are required: `path_to_lora1`, `path_to_lora2`, and `path_to_save` (see the sketch below).
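
  For example (a sketch: `add` is the function behind the `lora_add` entry point; the filenames are placeholders):

  ```python
  from lora_diffusion.cli_lora_add import add

  # Sketch only: join two LoRAs (and their token embeddings) into one file.
  add(
      path_1="lora_a.safetensors",
      path_2="lora_b.safetensors",
      output_path="joined.safetensors",
      mode="ljl",
  )
  ```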
Binary file added contents/inpainting_base_image.png
Binary file added contents/inpainting_mask.png
Binary file added contents/lora_pti_inpainting.jpg
Binary file added contents/lora_pti_inpainting_example.jpg
Binary file added example_loras/and.safetensors
Binary file added example_loras/lora_krk_inpainting.safetensors
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
@@ -2,3 +2,4 @@
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora_manager import *
50 changes: 2 additions & 48 deletions lora_diffusion/cli_lora_add.py
@@ -12,6 +12,7 @@
    collapse_lora,
    monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt


@@ -20,53 +21,6 @@ def _text_lora_path(path: str) -> str:
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def lora_join(lora_safetenors: list):
    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
    total_metadata = {}
    total_tensor = {}
    total_rank = 0
    for _metadata in metadatas:
        rankset = []
        for k, v in _metadata.items():
            if k.endswith("rank"):
                rankset.append(int(v))

        assert len(set(rankset)) == 1, "Rank should be the same per model"
        total_rank += rankset[0]
        total_metadata.update(_metadata)

    tensorkeys = set()
    for safelora in lora_safetenors:
        tensorkeys.update(safelora.keys())

    for keys in tensorkeys:
        if keys.startswith("text_encoder") or keys.startswith("unet"):
            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

            is_down = keys.endswith("down")

            if is_down:
                _tensor = torch.cat(tensorset, dim=0)
                assert _tensor.shape[0] == total_rank
            else:
                _tensor = torch.cat(tensorset, dim=1)
                assert _tensor.shape[1] == total_rank

            total_tensor[keys] = _tensor
            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
            total_metadata[keys_rank] = str(total_rank)

    for idx, safelora in enumerate(lora_safetenors):
        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
        for jdx, token in enumerate(sorted(tokens)):
            del total_metadata[token]
            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

    return total_tensor, total_metadata


def add(
    path_1: str,
    path_2: str,
@@ -221,7 +175,7 @@ def add(
        safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
        safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
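        # Note: lora_join, now imported from .lora_manager, returns two extra
        # values beyond the joined tensors and metadata; they are unused here.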
        save_file(total_tensor, output_path, total_metadata)

    else:
87 changes: 81 additions & 6 deletions lora_diffusion/cli_lora_pti.py
@@ -168,13 +168,60 @@ def collate_fn(examples):

    return train_dataloader

def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [example["instance_masked_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [example["class_masked_images"] for example in examples]

        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    return train_dataloader
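
# Illustration (not from this PR; assumes train_batch_size=2, 512x512 images,
# no prior-preservation examples, and CLIP's 77-token max length):
#   batch["pixel_values"]        -> (2, 3, 512, 512)   original images
#   batch["mask_values"]         -> (2, 1, 512, 512)   random cutout masks
#   batch["masked_image_values"] -> (2, 3, 512, 512)   images with masked regions zeroed
#   batch["input_ids"]           -> (2, 77)            padded token ids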

def loss_step(
    batch,
    unet,
    vae,
    text_encoder,
    scheduler,
    train_inpainting=False,
    t_mutliplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
@@ -186,6 +233,16 @@ def loss_step(
    ).latent_dist.sample()
    latents = latents * 0.18215

    if train_inpainting:
        masked_image_latents = vae.encode(
            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
        ).latent_dist.sample()
        masked_image_latents = masked_image_latents * 0.18215
        mask = F.interpolate(
            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
            scale_factor=1 / 8,
        )

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

@@ -199,21 +256,26 @@

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
    else:
        latent_model_input = noisy_latents
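
    # Note (illustrative; assumes a standard Stable Diffusion inpainting UNet):
    # conv_in expects 4 (noisy latents) + 1 (mask) + 4 (masked-image latents) = 9
    # input channels, which is why the mask was downsampled by 1/8 above to the
    # latent resolution, e.g. (B, 1, 64, 64) alongside (B, 4, 64, 64) latents
    # for 512x512 training images.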

    if mixed_precision:
        with torch.cuda.amp.autocast():

            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(text_encoder.device)
            )[0]

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
    else:

        encoder_hidden_states = text_encoder(
            batch["input_ids"].to(text_encoder.device)
        )[0]

        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

    if scheduler.config.prediction_type == "epsilon":
        target = noise
@@ -270,6 +332,7 @@ def train_inversion(
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
    mixed_precision: bool = False,
    clip_ti_decay: bool = True,
):
@@ -302,6 +365,7 @@ def train_inversion(
                        vae,
                        text_encoder,
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
                    )
                    / accum_iter
@@ -384,7 +448,7 @@ def train_inversion(
                    # open all images in test_image_path
                    images = []
                    for file in os.listdir(test_image_path):
                        if file.endswith(".png") or file.endswith(".jpg") or file.endswith(".jpeg"):
                        if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
                            images.append(
                                Image.open(os.path.join(test_image_path, file))
                            )
@@ -429,6 +493,7 @@ def perform_tuning(
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):

    progress_bar = tqdm(range(num_steps))
Expand Down Expand Up @@ -457,6 +522,7 @@ def perform_tuning(
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
@@ -565,6 +631,7 @@ def train(
    stochastic_attribute: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
@@ -716,13 +783,19 @@ def train(
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    train_dataloader = text2img_dataloader(
        train_dataset, train_batch_size, tokenizer, vae, text_encoder
    )
    if train_inpainting:
        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )

    index_no_updates = torch.arange(len(tokenizer)) != -1

@@ -776,6 +849,7 @@ def train(
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
@@ -883,6 +957,7 @@ def train(
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )


33 changes: 32 additions & 1 deletion lora_diffusion/dataset.py
@@ -5,6 +5,7 @@
import cv2
import numpy as np
from PIL import Image, ImageFilter
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms
import glob
@@ -85,6 +86,29 @@ def _shuffle(lis):
    return random.sample(lis, len(lis))


def _get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128):
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def _generate_random_mask(image):
    mask = zeros_like(image[:1])
    holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.0
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)
    masked_image = image * (mask < 0.5)
    return mask, masked_image
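
# Illustration (not from this PR): for a (3, 512, 512) image tensor `img`,
#   mask, masked = _generate_random_mask(img)
# yields a (1, 512, 512) mask that is 1.0 inside the holes, and `masked` equal
# to `img` where the mask is 0 and zeroed elsewhere; with probability 0.25 the
# whole mask is filled, so fully-masked examples are also seen in training.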

class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
@@ -106,11 +130,13 @@ def __init__(
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        instance_data_root = Path(instance_data_root)
        if not instance_data_root.exists():
@@ -247,6 +273,9 @@ def __getitem__(self, index):
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            example["instance_masks"], example["instance_masked_images"] = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            input_tok = list(self.token_map.values())[0]
@@ -267,7 +296,7 @@
                    Image.open(self.mask_path[index % self.num_instance_images])
                )
                * 0.5
                + 1
                + 0.5
            )

        if self.h_flip and random.random() > 0.5:
@@ -291,6 +320,8 @@ def __getitem__(self, index):
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            if self.train_inpainting:
                example["class_masks"], example["class_masked_images"] = _generate_random_mask(example["class_images"])
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
2 changes: 2 additions & 0 deletions lora_diffusion/lora.py
@@ -974,6 +974,8 @@ def patch_pipe(
        unet_path = maybe_unet_path[:-6] + ".pt"
    elif maybe_unet_path.endswith(".text_encoder.pt"):
        unet_path = maybe_unet_path[:-16] + ".pt"
    else:
        unet_path = maybe_unet_path
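        # Fallback: any other path is used as the UNet weights path as-is.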

    ti_path = _ti_lora_path(unet_path)
    text_path = _text_lora_path(unet_path)
