Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebasin fix #37

Merged
merged 2 commits into from
Oct 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions scripts/model_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _all_blocks(isxl=False):
blocks = [ base_prefix ]
for i in range(0, BLOCKLEN):
blocks.append(f"input_blocks.{i}.")
blocks.append("middle_block.1.")
blocks.append("middle_block.")
for i in range(0, BLOCKLEN):
blocks.append(f"output_blocks.{i}.")
return blocks
Expand Down Expand Up @@ -478,8 +478,15 @@ def mm_list_models():
def get_rebasin_perms(mbws, isxl):
"""all blocks permutations of selected blocks"""

selected = get_selected_blocks(mbws, isxl)
all_blocks = _all_blocks(isxl)
if True in mbws or False in mbws: # already have selected
_selected = mbws
all_blocks = _all_blocks(isxl)
selected = []
for i, v in enumerate(_selected):
if v:
selected.append(all_blocks[i])
else:
selected = get_selected_blocks(mbws, isxl)

if len(selected) > 0:
axes = []
Expand All @@ -496,7 +503,24 @@ def get_rebasin_perms(mbws, isxl):
return None


def get_rebasin_groups(mbws, isxl):
def get_rebasin_axes(mbws, isxl):
    """Return the unique axis names covered by the selected permutation groups.

    The permutation groups for the selection are obtained from
    ``get_rebasin_perms``; for each group, the first element of every entry
    in ``permutation_spec.perm_to_axes[group]`` is collected.

    Args:
        mbws: block selection, forwarded unchanged to ``get_rebasin_perms``.
        isxl: model-variant flag, forwarded unchanged to ``get_rebasin_perms``.

    Returns:
        A de-duplicated list of the collected names (order is not defined,
        since duplicates are removed through a ``set``), or ``None`` when
        ``get_rebasin_perms`` returns ``None``.
    """
    perms = get_rebasin_perms(mbws, isxl)
    if perms is None:
        return None

    # Only element [0] of each perm_to_axes entry is used; presumably each
    # entry is a (weight key, axis index) pair -- TODO confirm against
    # permutation_spec.
    names = []
    for perm in perms:
        names += [entry[0] for entry in permutation_spec.perm_to_axes[perm]]

    # Drop duplicates before returning.
    return list(set(names))


def _get_rebasin_blocks(mbws, isxl):
"""select all blocks correspond their permutation groups"""

perms = get_rebasin_perms(mbws, isxl)
Expand All @@ -514,16 +538,15 @@ def get_rebasin_groups(mbws, isxl):
MAXLEN = 26 - (0 if not isxl else 6)
BLOCKLEN = 12 - (0 if not isxl else 3)
BLOCKOFFSET = 13 if not isxl else 10
new_selected = [False]*MAXLEN
selected = [False]*MAXLEN
BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL

all_blocks = _all_blocks(isxl)
blocknames = []
for j, block in enumerate(all_blocks):
if any(block in axe for axe in axes):
blocknames.append(BLOCKIDS[j])
selected[j] = True

return blocknames
return selected


def get_device():
Expand Down Expand Up @@ -1939,6 +1962,22 @@ def load_state_dict(checkpoint_info):
mm_selected_all[k] = mm_selected_all[k] or elemental_selected[k]

all_blocks = _all_blocks(isxl)
BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL

# get all blocks affected by same perm groups by rebasin merge
if not isxl and "Rebasin" in mm_calcmodes:
print("check affected permutation blocks by rebasin merge...")
jj = 0
while True:
xx_selected_all = _get_rebasin_blocks(mm_selected_all, isxl)
changed = [BLOCKIDS[i] for i, v in enumerate(mm_selected_all) if v != xx_selected_all[i]]
if len(changed) > 0:
print(f" - [{jj+1}] {changed} block{'s' if len(changed) > 1 else ''} added")
mm_selected_all = xx_selected_all
jj += 1
else:
break

for k in range(max_blocks):
if mm_selected_all[k]:
selected_blocks.append(all_blocks[k])
Expand Down