Skip to content

Commit

Permalink
fix #36 bug
Browse files Browse the repository at this point in the history
 * get all same permutation blocks
  • Loading branch information
wkpark committed Oct 21, 2023
1 parent 887d73e commit bfb79fa
Showing 1 changed file with 46 additions and 7 deletions.
53 changes: 46 additions & 7 deletions scripts/model_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,15 @@ def mm_list_models():
def get_rebasin_perms(mbws, isxl):
"""all blocks permutations of selected blocks"""

selected = get_selected_blocks(mbws, isxl)
all_blocks = _all_blocks(isxl)
if True in mbws or False in mbws: # already have selected
_selected = mbws
all_blocks = _all_blocks(isxl)
selected = []
for i, v in enumerate(_selected):
if v:
selected.append(all_blocks[i])
else:
selected = get_selected_blocks(mbws, isxl)

if len(selected) > 0:
axes = []
Expand All @@ -496,7 +503,24 @@ def get_rebasin_perms(mbws, isxl):
return None


def get_rebasin_groups(mbws, isxl):
def get_rebasin_axes(mbws, isxl):
    """Return the axes entries of the permutation groups for the selected blocks.

    Args:
        mbws: block selection, forwarded unchanged to ``get_rebasin_perms``.
        isxl: forwarded unchanged to ``get_rebasin_perms``.

    Returns:
        A de-duplicated list holding the first element of every
        ``permutation_spec.perm_to_axes[perm]`` entry for each permutation
        reported by ``get_rebasin_perms``, or ``None`` when that call
        returns ``None``. The values pass through a ``set``, so the order
        of the returned list is unspecified.
    """
    perms = get_rebasin_perms(mbws, isxl)
    if perms is None:
        return None

    # Gather the first element of every perm_to_axes entry, dropping
    # duplicates as we go.
    # NOTE(review): each entry is presumably a (key name, axis) pair and
    # only the key name is kept -- confirm against permutation_spec.
    unique_axes = {
        entry[0]
        for perm in perms
        for entry in permutation_spec.perm_to_axes[perm]
    }
    return list(unique_axes)


def _get_rebasin_blocks(mbws, isxl):
"""select all blocks correspond their permutation groups"""

perms = get_rebasin_perms(mbws, isxl)
Expand All @@ -514,16 +538,15 @@ def get_rebasin_groups(mbws, isxl):
MAXLEN = 26 - (0 if not isxl else 6)
BLOCKLEN = 12 - (0 if not isxl else 3)
BLOCKOFFSET = 13 if not isxl else 10
new_selected = [False]*MAXLEN
selected = [False]*MAXLEN
BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL

all_blocks = _all_blocks(isxl)
blocknames = []
for j, block in enumerate(all_blocks):
if any(block in axe for axe in axes):
blocknames.append(BLOCKIDS[j])
selected[j] = True

return blocknames
return selected


def get_device():
Expand Down Expand Up @@ -1939,6 +1962,22 @@ def load_state_dict(checkpoint_info):
mm_selected_all[k] = mm_selected_all[k] or elemental_selected[k]

all_blocks = _all_blocks(isxl)
BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL

# get all blocks affected by same perm groups by rebasin merge
if not isxl and "Rebasin" in mm_calcmodes:
print("check affected permutation blocks by rebasin merge...")
jj = 0
while True:
xx_selected_all = _get_rebasin_blocks(mm_selected_all, isxl)
changed = [BLOCKIDS[i] for i, v in enumerate(mm_selected_all) if v != xx_selected_all[i]]
if len(changed) > 0:
print(f" - [{jj+1}] {changed} block{'s' if len(changed) > 1 else ''} added")
mm_selected_all = xx_selected_all
jj += 1
else:
break

for k in range(max_blocks):
if mm_selected_all[k]:
selected_blocks.append(all_blocks[k])
Expand Down

0 comments on commit bfb79fa

Please sign in to comment.