Skip to content

Commit

Permalink
del token during join
Browse files Browse the repository at this point in the history
  • Loading branch information
cloneofsimo committed Feb 4, 2023
1 parent ef2b17c commit 65c9200
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions lora_diffusion/cli_lora_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def lora_join(lora_safetenors: list):
total_metadata = {}
total_tensor = {}
total_rank = 0
ranklist = []
for _metadata in metadatas:
rankset = []
for k, v in _metadata.items():
Expand All @@ -34,6 +35,7 @@ def lora_join(lora_safetenors: list):
assert len(set(rankset)) == 1, "Rank should be the same per model"
total_rank += rankset[0]
total_metadata.update(_metadata)
ranklist.append(rankset[0])

tensorkeys = set()
for safelora in lora_safetenors:
Expand All @@ -55,16 +57,19 @@ def lora_join(lora_safetenors: list):
total_tensor[keys] = _tensor
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
total_metadata[keys_rank] = str(total_rank)

token_size_list = []
for idx, safelora in enumerate(lora_safetenors):
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
for jdx, token in enumerate(sorted(tokens)):
del total_metadata[token]
if total_metadata.get(token, None) is not None:
del total_metadata[token]
total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

return total_tensor, total_metadata
token_size_list.append(len(tokens))

return total_tensor, total_metadata, ranklist, token_size_list


def add(
Expand Down Expand Up @@ -221,7 +226,7 @@ def add(
safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

total_tensor, total_metadata = lora_join([safeloras_1, safeloras_2])
total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
save_file(total_tensor, output_path, total_metadata)

else:
Expand Down

0 comments on commit 65c9200

Please sign in to comment.