Skip to content

Commit

Permalink
manager order
Browse files Browse the repository at this point in the history
  • Loading branch information
cloneofsimo committed Feb 9, 2023
1 parent 0ddce0b commit 7be7bed
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
Binary file added example_loras/and.safetensors
Binary file not shown.
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora_manager import *
10 changes: 7 additions & 3 deletions lora_diffusion/lora_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
Expand Down Expand Up @@ -50,12 +51,15 @@ def lora_join(lora_safetenors: list):
for idx, safelora in enumerate(lora_safetenors):
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
for jdx, token in enumerate(sorted(tokens)):
if total_metadata.get(token, None) is not None:
del total_metadata[token]

total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"

print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

if total_metadata.get(token, None) is not None:
del total_metadata[token]

token_size_list.append(len(tokens))

return total_tensor, total_metadata, ranklist, token_size_list
Expand All @@ -77,7 +81,7 @@ def get_tensor(self, key):


class LoRAManager:
def __init__(self, lora_paths_list, pipe: StableDiffusionPipeline):
def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):

self.lora_paths_list = lora_paths_list
self.pipe = pipe
Expand Down

0 comments on commit 7be7bed

Please sign in to comment.