diff --git a/example_loras/and.safetensors b/example_loras/and.safetensors
new file mode 100644
index 0000000..c615fc1
Binary files /dev/null and b/example_loras/and.safetensors differ
diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py
index 984b4f6..286e4fe 100644
--- a/lora_diffusion/__init__.py
+++ b/lora_diffusion/__init__.py
@@ -2,3 +2,4 @@
 from .dataset import *
 from .utils import *
 from .preprocess_files import *
+from .lora_manager import *
diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py
index ae33e3d..fc7f7e4 100644
--- a/lora_diffusion/cli_lora_add.py
+++ b/lora_diffusion/cli_lora_add.py
@@ -12,6 +12,7 @@
     collapse_lora,
     monkeypatch_remove_lora,
 )
+from .lora_manager import lora_join
 from .to_ckpt_v2 import convert_to_ckpt


@@ -20,58 +21,6 @@ def _text_lora_path(path: str) -> str:
     return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


-def lora_join(lora_safetenors: list):
-    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
-    total_metadata = {}
-    total_tensor = {}
-    total_rank = 0
-    ranklist = []
-    for _metadata in metadatas:
-        rankset = []
-        for k, v in _metadata.items():
-            if k.endswith("rank"):
-                rankset.append(int(v))
-
-        assert len(set(rankset)) == 1, "Rank should be the same per model"
-        total_rank += rankset[0]
-        total_metadata.update(_metadata)
-        ranklist.append(rankset[0])
-
-    tensorkeys = set()
-    for safelora in lora_safetenors:
-        tensorkeys.update(safelora.keys())
-
-    for keys in tensorkeys:
-        if keys.startswith("text_encoder") or keys.startswith("unet"):
-            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
-
-            is_down = keys.endswith("down")
-
-            if is_down:
-                _tensor = torch.cat(tensorset, dim=0)
-                assert _tensor.shape[0] == total_rank
-            else:
-                _tensor = torch.cat(tensorset, dim=1)
-                assert _tensor.shape[1] == total_rank
-
-            total_tensor[keys] = _tensor
-            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
-            total_metadata[keys_rank] = str(total_rank)
-    token_size_list = []
-    for idx, safelora in enumerate(lora_safetenors):
-        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
-        for jdx, token in enumerate(sorted(tokens)):
-            if total_metadata.get(token, None) is not None:
-                del total_metadata[token]
-            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
-            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
-            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
-
-        token_size_list.append(len(tokens))
-
-    return total_tensor, total_metadata, ranklist, token_size_list
-
-
 def add(
     path_1: str,
     path_2: str,
diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py
new file mode 100644
index 0000000..2ef9608
--- /dev/null
+++ b/lora_diffusion/lora_manager.py
@@ -0,0 +1,134 @@
+from typing import List
+import torch
+from safetensors import safe_open
+from diffusers import StableDiffusionPipeline
+from .lora import (
+    monkeypatch_or_replace_safeloras,
+    apply_learned_embed_in_clip,
+    set_lora_diag,
+    parse_safeloras_embeds,
+)
+
+
+def lora_join(lora_safetenors: list):
+    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
+    total_metadata = {}
+    total_tensor = {}
+    total_rank = 0
+    ranklist = []
+    for _metadata in metadatas:
+        rankset = []
+        for k, v in _metadata.items():
+            if k.endswith("rank"):
+                rankset.append(int(v))
+
+        assert len(set(rankset)) == 1, "Rank should be the same per model"
+        total_rank += rankset[0]
+        total_metadata.update(_metadata)
+        ranklist.append(rankset[0])
+
+    tensorkeys = set()
+    for safelora in lora_safetenors:
+        tensorkeys.update(safelora.keys())
+
+    for keys in tensorkeys:
+        if keys.startswith("text_encoder") or keys.startswith("unet"):
+            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
+
+            is_down = keys.endswith("down")
+
+            if is_down:
+                _tensor = torch.cat(tensorset, dim=0)
+                assert _tensor.shape[0] == total_rank
+            else:
+                _tensor = torch.cat(tensorset, dim=1)
+                assert _tensor.shape[1] == total_rank
+
+            total_tensor[keys] = _tensor
+            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
+            total_metadata[keys_rank] = str(total_rank)
+    token_size_list = []
+    for idx, safelora in enumerate(lora_safetenors):
+        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
+        for jdx, token in enumerate(sorted(tokens)):
+
+            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
+            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
+
+            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
+
+            if total_metadata.get(token, None) is not None:
+                del total_metadata[token]
+
+        token_size_list.append(len(tokens))
+
+    return total_tensor, total_metadata, ranklist, token_size_list
+
+
+class DummySafeTensorObject:
+    def __init__(self, tensor: dict, metadata):
+        self.tensor = tensor
+        self._metadata = metadata
+
+    def keys(self):
+        return self.tensor.keys()
+
+    def metadata(self):
+        return self._metadata
+
+    def get_tensor(self, key):
+        return self.tensor[key]
+
+
+class LoRAManager:
+    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
+
+        self.lora_paths_list = lora_paths_list
+        self.pipe = pipe
+        self._setup()
+
+    def _setup(self):
+
+        self._lora_safetenors = [
+            safe_open(path, framework="pt", device="cpu")
+            for path in self.lora_paths_list
+        ]
+
+        (
+            total_tensor,
+            total_metadata,
+            self.ranklist,
+            self.token_size_list,
+        ) = lora_join(self._lora_safetenors)
+
+        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
+
+        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
+        tok_dict = parse_safeloras_embeds(self.total_safelora)
+
+        apply_learned_embed_in_clip(
+            tok_dict,
+            self.pipe.text_encoder,
+            self.pipe.tokenizer,
+            token=None,
+            idempotent=True,
+        )
+
+    def tune(self, scales):
+
+        diags = []
+        for scale, rank in zip(scales, self.ranklist):
+            diags = diags + [scale] * rank
+
+        set_lora_diag(self.pipe.unet, torch.tensor(diags))
+
+    def prompt(self, prompt):
+        if prompt is not None:
+            for idx, tok_size in enumerate(self.token_size_list):
+                prompt = prompt.replace(
+                    f"<{idx + 1}>",
+                    "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
+                )
+        # TODO : Rescale LoRA + Text inputs based on prompt scale params
+
+        return prompt
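A minimal usage sketch of the new LoRAManager: it joins the given safetensors LoRAs into a single patched pipeline, exposes one scale per joined model through tune(), and rewrites <1>, <2>, ... placeholders into the re-mapped embedding tokens through prompt(). The base model id, the second LoRA path, the prompt, and the output filename below are illustrative placeholders, not part of this diff.

from diffusers import StableDiffusionPipeline
from lora_diffusion import LoRAManager

# Base pipeline; the model id is only an example.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

# Join two LoRAs; the second path is a hypothetical placeholder.
manager = LoRAManager(
    ["example_loras/and.safetensors", "example_loras/other.safetensors"], pipe
)

# One scale per joined LoRA, applied to that model's rank slice of the merged weights.
manager.tune([1.0, 0.6])

# <1> / <2> expand to the embedding tokens of the first / second LoRA.
prompt = manager.prompt("a photo of <1> in the style of <2>")
image = pipe(prompt, num_inference_steps=30).images[0]
image.save("joined_lora.png")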