Skip to content

Commit

Permalink
Merge pull request #161 from cloneofsimo/develop
Browse files Browse the repository at this point in the history
v0.1.3
  • Loading branch information
cloneofsimo authored Jan 31, 2023
2 parents 1707928 + 2028b9e commit 2e1ffae
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 39 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,16 @@

# Web Demo

- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ysharma/Low-rank-Adaptation)
- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/lora-library/LoRA-DreamBooth-Training-UI)

- Easy [colab running example](https://colab.research.google.com/drive/1iSFDpRBKEWr2HLlz243rbym3J2X95kcy?usp=sharing) of Dreambooth by @pedrogengo

# UPDATES & Notes

### 2022/02/01

- LoRA Joining is now available with `--mode=ljl` flag. Only three parameters are required : `path_to_lora1`, `path_to_lora2`, and `path_to_save`.

### 2022/01/29

- Dataset pipelines
Expand Down Expand Up @@ -106,7 +110,7 @@ First, there is LoRA applied to Dreambooth. The idea is to use prior-preservatio

2. [Textual Inversion](https://arxiv.org/abs/2208.01618)

Second, there is Textual inversion. There is no room to apply LoRA here, but it is worth mensioning. The idea is to instantiate new token, and learn the token embedding via gradient descent. This is a very powerful method, and it is worth trying out if your use case is not focused on fidelity but rather on inverting conceptual ideas.
Second, there is Textual inversion. There is no room to apply LoRA here, but it is worth mentioning. The idea is to instantiate new token, and learn the token embedding via gradient descent. This is a very powerful method, and it is worth trying out if your use case is not focused on fidelity but rather on inverting conceptual ideas.

3. [Pivotal Tuning](https://arxiv.org/abs/2106.05744)

Expand Down
Binary file added example_loras/concat_disney_krk.safetensors
Binary file not shown.
58 changes: 58 additions & 0 deletions lora_diffusion/cli_lora_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,53 @@ def _text_lora_path(path: str) -> str:
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def lora_join(lora_safetenors: list):
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
total_metadata = {}
total_tensor = {}
total_rank = 0
for _metadata in metadatas:
rankset = []
for k, v in _metadata.items():
if k.endswith("rank"):
rankset.append(int(v))

assert len(set(rankset)) == 1, "Rank should be the same per model"
total_rank += rankset[0]
total_metadata.update(_metadata)

tensorkeys = set()
for safelora in lora_safetenors:
tensorkeys.update(safelora.keys())

for keys in tensorkeys:
if keys.startswith("text_encoder") or keys.startswith("unet"):
tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

is_down = keys.endswith("down")

if is_down:
_tensor = torch.cat(tensorset, dim=0)
assert _tensor.shape[0] == total_rank
else:
_tensor = torch.cat(tensorset, dim=1)
assert _tensor.shape[1] == total_rank

total_tensor[keys] = _tensor
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
total_metadata[keys_rank] = str(total_rank)

for idx, safelora in enumerate(lora_safetenors):
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
for jdx, token in enumerate(sorted(tokens)):
del total_metadata[token]
total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

return total_tensor, total_metadata


def add(
path_1: str,
path_2: str,
Expand Down Expand Up @@ -165,6 +212,17 @@ def add(
print(
f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
)
elif mode == "ljl":
print("Using Join mode : alpha will not have an effect here.")
assert path_1.endswith(".safetensors") and path_2.endswith(
".safetensors"
), "Only .safetensors files are supported"

safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

total_tensor, total_metadata = lora_join([safeloras_1, safeloras_2])
save_file(total_tensor, output_path, total_metadata)

else:
print("Unknown mode", mode)
Expand Down
12 changes: 11 additions & 1 deletion lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ def train(
lora_rank: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
lora_dropout_p: float = 0.0,
lora_scale: float = 1.0,
use_extended_lora: bool = False,
clip_ti_decay: bool = True,
learning_rate_unet: float = 1e-4,
Expand Down Expand Up @@ -578,6 +580,10 @@ def train(
else:
placeholder_tokens = placeholder_tokens.split("|")

assert (
sorted(placeholder_tokens) == placeholder_tokens
), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"

if initializer_tokens is None:
print("PTI : Initializer Tokens not given, doing random inits")
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
Expand Down Expand Up @@ -720,7 +726,11 @@ def train(
# Next perform Tuning with LoRA:
if not use_extended_lora:
unet_lora_params, _ = inject_trainable_lora(
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
unet,
r=lora_rank,
target_replace_module=lora_unet_target_modules,
dropout_p=lora_dropout_p,
scale=lora_scale,
)
else:
print("PTI : USING EXTENDED UNET!!!")
Expand Down
5 changes: 4 additions & 1 deletion lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def __init__(
)

masks = face_mask_google_mediapipe(
[Image.open(f) for f in self.instance_images_path]
[
Image.open(f).convert("RGB")
for f in self.instance_images_path
]
)
for idx, mask in enumerate(masks):
mask.save(f"{instance_data_root}/{idx}.mask.png")
Expand Down
109 changes: 92 additions & 17 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,45 @@ def safe_save(


class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1):
def __init__(
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
):
super().__init__()

if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
)

self.r = r
self.linear = nn.Linear(in_features, out_features, bias)
self.lora_down = nn.Linear(in_features, r, bias=False)
self.dropout = nn.Dropout(dropout_p)
self.lora_up = nn.Linear(r, out_features, bias=False)
self.scale = 1.0
self.scale = scale
self.selector = nn.Identity()

nn.init.normal_(self.lora_down.weight, std=1 / r)
nn.init.zeros_(self.lora_up.weight)

def forward(self, input):
return (
self.linear(input)
+ self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
* self.scale
)

def realize_as_lora(self):
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

def set_selector_from_diag(self, diag: torch.Tensor):
# diag is a 1D tensor of size (r,)
assert diag.shape == (self.r,)
self.selector = nn.Linear(self.r, self.r, bias=False)
self.selector.weight.data = torch.diag(diag)
self.selector.weight.data = self.selector.weight.data.to(
self.lora_up.weight.device
).to(self.lora_up.weight.dtype)


class LoraInjectedConv2d(nn.Module):
def __init__(
Expand All @@ -67,13 +83,14 @@ def __init__(
bias: bool = True,
r: int = 4,
dropout_p: float = 0.1,
scale: float = 1.0,
):
super().__init__()
if r > min(in_channels, out_channels):
raise ValueError(
f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
)

self.r = r
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -104,17 +121,40 @@ def __init__(
padding=0,
bias=False,
)
self.scale = 1.0
self.selector = nn.Identity()
self.scale = scale

nn.init.normal_(self.lora_down.weight, std=1 / r)
nn.init.zeros_(self.lora_up.weight)

def forward(self, input):
return (
self.conv(input)
+ self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
* self.scale
)

def realize_as_lora(self):
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

def set_selector_from_diag(self, diag: torch.Tensor):
# diag is a 1D tensor of size (r,)
assert diag.shape == (self.r,)
self.selector = nn.Conv2d(
in_channels=self.r,
out_channels=self.r,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.selector.weight.data = torch.diag(diag)

# same device + dtype as lora_up
self.selector.weight.data = self.selector.weight.data.to(
self.lora_up.weight.device
).to(self.lora_up.weight.dtype)


UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}

Expand Down Expand Up @@ -217,6 +257,9 @@ def inject_trainable_lora(
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
r: int = 4,
loras=None, # path to lora .pt
verbose: bool = False,
dropout_p: float = 0.0,
scale: float = 1.0,
):
"""
inject lora into model, and returns lora parameter groups.
Expand All @@ -233,11 +276,16 @@ def inject_trainable_lora(
):
weight = _child_module.weight
bias = _child_module.bias
if verbose:
print("LoRA Injection : injecting lora into ", name)
print("LoRA Injection : weight shape", weight.shape)
_tmp = LoraInjectedLinear(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r,
r=r,
dropout_p=dropout_p,
scale=scale,
)
_tmp.linear.weight = weight
if bias is not None:
Expand Down Expand Up @@ -287,7 +335,7 @@ def inject_trainable_lora_extended(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r,
r=r,
)
_tmp.linear.weight = weight
if bias is not None:
Expand All @@ -304,7 +352,7 @@ def inject_trainable_lora_extended(
_child_module.dilation,
_child_module.groups,
_child_module.bias is not None,
r,
r=r,
)

_tmp.conv.weight = weight
Expand Down Expand Up @@ -349,6 +397,30 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
return loras


def extract_lora_as_tensor(
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
):

loras = []

for _m, _n, _child_module in _find_modules(
model,
target_replace_module,
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
):
up, down = _child_module.realize_as_lora()
if as_fp16:
up = up.to(torch.float16)
down = down.to(torch.float16)

loras.append((up, down))

if len(loras) == 0:
raise ValueError("No lora injected.")

return loras


def save_lora_weight(
model,
path="./lora.pt",
Expand Down Expand Up @@ -395,16 +467,13 @@ def save_safeloras_with_embeds(
metadata[name] = json.dumps(list(target_replace_module))

for i, (_up, _down) in enumerate(
extract_lora_ups_down(model, target_replace_module)
extract_lora_as_tensor(model, target_replace_module)
):
try:
rank = getattr(_down, "out_features")
except:
rank = getattr(_down, "out_channels")
rank = _down.shape[0]

metadata[f"{name}:{i}:rank"] = str(rank)
weights[f"{name}:{i}:up"] = _up.weight
weights[f"{name}:{i}:down"] = _down.weight
weights[f"{name}:{i}:up"] = _up
weights[f"{name}:{i}:down"] = _down

for token, tensor in embeds.items():
metadata[token] = EMBED_FLAG
Expand Down Expand Up @@ -811,6 +880,12 @@ def tune_lora_scale(model, alpha: float = 1.0):
_module.scale = alpha


def set_lora_diag(model, diag: torch.Tensor):
for _module in model.modules():
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
_module.set_selector_from_diag(diag)


def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ fire
wandb
safetensors
opencv-python
torchvision
torchvision
mediapipe
Loading

0 comments on commit 2e1ffae

Please sign in to comment.