diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index f1b8cae..e6f563a 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -478,8 +478,15 @@ def mm_list_models(): def get_rebasin_perms(mbws, isxl): """all blocks permutations of selected blocks""" - selected = get_selected_blocks(mbws, isxl) - all_blocks = _all_blocks(isxl) + if True in mbws or False in mbws: # already have selected + _selected = mbws + all_blocks = _all_blocks(isxl) + selected = [] + for i, v in enumerate(_selected): + if v: + selected.append(all_blocks[i]) + else: + selected = get_selected_blocks(mbws, isxl) if len(selected) > 0: axes = [] @@ -496,7 +503,24 @@ def get_rebasin_perms(mbws, isxl): return None -def get_rebasin_groups(mbws, isxl): +def get_rebasin_axes(mbws, isxl): + """select all blocks correspond their permutation groups""" + + perms = get_rebasin_perms(mbws, isxl) + if perms is None: + return None + + # get all axes and corresponde blocks + blocks = [] + axes = [] + for perm in perms: + axes += [axes[0] for axes in permutation_spec.perm_to_axes[perm]] + axes = list(set(axes)) + + return axes + + +def _get_rebasin_blocks(mbws, isxl): """select all blocks correspond their permutation groups""" perms = get_rebasin_perms(mbws, isxl) @@ -514,16 +538,15 @@ def get_rebasin_groups(mbws, isxl): MAXLEN = 26 - (0 if not isxl else 6) BLOCKLEN = 12 - (0 if not isxl else 3) BLOCKOFFSET = 13 if not isxl else 10 - new_selected = [False]*MAXLEN + selected = [False]*MAXLEN BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL all_blocks = _all_blocks(isxl) - blocknames = [] for j, block in enumerate(all_blocks): if any(block in axe for axe in axes): - blocknames.append(BLOCKIDS[j]) + selected[j] = True - return blocknames + return selected def get_device(): @@ -1939,6 +1962,22 @@ def load_state_dict(checkpoint_info): mm_selected_all[k] = mm_selected_all[k] or elemental_selected[k] all_blocks = _all_blocks(isxl) + BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL + + # get all blocks affected by same perm groups by rebasin merge + if not isxl and "Rebasin" in mm_calcmodes: + print("check affected permutation blocks by rebasin merge...") + jj = 0 + while True: + xx_selected_all = _get_rebasin_blocks(mm_selected_all, isxl) + changed = [BLOCKIDS[i] for i, v in enumerate(mm_selected_all) if v != xx_selected_all[i]] + if len(changed) > 0: + print(f" - [{jj+1}] {changed} block{'s' if len(changed) > 1 else ''} added") + mm_selected_all = xx_selected_all + jj += 1 + else: + break + for k in range(max_blocks): if mm_selected_all[k]: selected_blocks.append(all_blocks[k])