Skip to content

Commit

Permalink
Merge pull request #164 from cloneofsimo/develop
Browse files Browse the repository at this point in the history
v0.1.4
  • Loading branch information
cloneofsimo authored Feb 1, 2023
2 parents 437cb62 + 82ba343 commit b1d8293
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 7 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@

# UPDATES & Notes

### 2022/02/01
### 2023/02/01

- LoRA Joining is now available with the `--mode=ljl` flag. Only three parameters are required: `path_to_lora1`, `path_to_lora2`, and `path_to_save`.

### 2022/01/29
### 2023/01/29

- Dataset pipelines
- LoRA is now applied to Resnet as well; pass `--use_extended_lora` to enable it.
- SVD distillation now supports resnet-lora as well.
- The CompVis format conversion script now works with safetensors, and for PTI it will also return the Textual Inversion format, so you can use it in the embeddings folder.
- 🥳🥳, LoRA is now officially integrated into the amazing Huggingface 🤗 `diffusers` library! Check out the [Blog](https://huggingface.co/blog/lora) and [examples](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora)! (NOTE: it CURRENTLY uses a DIFFERENT FILE FORMAT)

### 2022/01/09
### 2023/01/09

- Pivotal Tuning Inversion with extended latent
- Better textual inversion with Norm prior
Expand Down
72 changes: 69 additions & 3 deletions lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ def perform_tuning(
lora_unet_target_modules,
lora_clip_target_modules,
mask_temperature,
out_name: str,
tokenizer,
test_image_path: str,
log_wandb: bool = False,
wandb_log_prompt_cnt: int = 10,
class_token: str = "person",
):

progress_bar = tqdm(range(num_steps))
Expand All @@ -434,6 +440,11 @@ def perform_tuning(
unet.train()
text_encoder.train()

if log_wandb:
preped_clip = prepare_clip_model_sets()

loss_sum = 0.0

for epoch in range(math.ceil(num_steps / len(dataloader))):
for batch in dataloader:
lr_scheduler_lora.step()
Expand All @@ -450,6 +461,8 @@ def perform_tuning(
mixed_precision=True,
mask_temperature=mask_temperature,
)
loss_sum += loss.detach().item()

loss.backward()
torch.nn.utils.clip_grad_norm_(
itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
Expand Down Expand Up @@ -493,15 +506,59 @@ def perform_tuning(

print("LORA CLIP Moved", moved)

if log_wandb:
with torch.no_grad():
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
)

# open all images in test_image_path
images = []
for file in os.listdir(test_image_path):
if file.endswith(".png") or file.endswith(".jpg"):
images.append(
Image.open(os.path.join(test_image_path, file))
)

wandb.log({"loss": loss_sum / save_steps})
loss_sum = 0.0
wandb.log(
evaluate_pipe(
pipe,
target_images=images,
class_token=class_token,
learnt_token="".join(placeholder_tokens),
n_test=wandb_log_prompt_cnt,
n_step=50,
clip_model_sets=preped_clip,
)
)

if global_step >= num_steps:
return
break

save_all(
unet,
text_encoder,
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(save_path, f"{out_name}.safetensors"),
target_replace_module_text=lora_clip_target_modules,
target_replace_module_unet=lora_unet_target_modules,
)


def train(
instance_data_dir: str,
pretrained_model_name_or_path: str,
output_dir: str,
train_text_encoder: bool = False,
train_text_encoder: bool = True,
pretrained_vae_name_or_path: str = None,
revision: Optional[str] = None,
class_data_dir: Optional[str] = None,
Expand Down Expand Up @@ -555,7 +612,9 @@ def train(
wandb_log_prompt_cnt: int = 10,
wandb_project_name: str = "new_pti_project",
wandb_entity: str = "new_pti_entity",
proxy_token: str = "person",
enable_xformers_memory_efficient_attention: bool = False,
out_name: str = "final_lora",
):
torch.manual_seed(seed)

Expand All @@ -566,7 +625,6 @@ def train(
name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
reinit=True,
config={
"lr": learning_rate_ti,
**(extra_args if extra_args is not None else {}),
},
)
Expand Down Expand Up @@ -594,6 +652,8 @@ def train(
placeholder_tokens
), "Unequal Initializer token for Placeholder tokens."

if proxy_token is not None:
class_token = proxy_token
class_token = "".join(initializer_tokens)

if placeholder_token_at_data is not None:
Expand Down Expand Up @@ -817,6 +877,12 @@ def train(
lora_unet_target_modules=lora_unet_target_modules,
lora_clip_target_modules=lora_clip_target_modules,
mask_temperature=mask_temperature,
tokenizer=tokenizer,
out_name=out_name,
test_image_path=instance_data_dir,
log_wandb=log_wandb,
wandb_log_prompt_cnt=wandb_log_prompt_cnt,
class_token=class_token,
)


Expand Down
1 change: 1 addition & 0 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
transforms.ColorJitter(0.1, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.CenterCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.1.3",
version="0.1.4",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),
Expand Down

0 comments on commit b1d8293

Please sign in to comment.