From d0979bca61ac91f876cc718c395d5f9d86a0fceb Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 17 Nov 2023 01:41:56 +0900 Subject: [PATCH] update partial block merge method * more flexible partial update. only the first model(model_a) need to be the same. * simplified modified blocks check method. --- scripts/model_mixer.py | 67 ++++++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 8988c99..23f4fa7 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -2506,33 +2506,50 @@ def add_difference(theta0, theta1, base, alpha): if use_unet_partial_update and current is not None: # check same models used hashes = current["hashes"] - same_models = True - if len(hashes) != len(mm_models) + 1: - same_models = False - for j, m in enumerate([model_a] + [*mm_models]): - info = sd_models.get_closet_checkpoint_match(m) - if info is None: - same_models = False - break + # only the first model need to be checked + first_model_is_the_same = False + + # check model_a + info = sd_models.get_closet_checkpoint_match(model_a) + if info is not None: if info.shorthash is None: info.calculate_shorthash() - if j >= len(hashes) or hashes[j] != info.shorthash: - same_models = False - break - - if same_models: - print(" - check possible UNet partial update...") - - if same_models and current["calcmode"] == mm_calcmodes and current["mode"] == mm_modes: - max_blocks = 26 - (0 if not isxl else 6) - - # check changed weights - weights = current["weights"] - while len(weights) > 0: - changed = [False] * len(weights[0]) - for j, w in enumerate(mm_weights): - changed |= np.array(weights[j][:max_blocks]) != np.array(w[:max_blocks]) + if hashes[0] == info.shorthash: + first_model_is_the_same = True + + if first_model_is_the_same: + print(" - check possible UNet partial update...") + max_blocks = 26 - (0 if not isxl else 6) + + # check changed weights + weights = current["weights"] + changed = [False] * max_blocks + + for j, m in enumerate(mm_models): + info = sd_models.get_closet_checkpoint_match(m) + if info is None: + if info.shorthash is None: + info.calculate_shorthash() + + same_model = True + # is it different model? + if j + 1 >= len(hashes) or hashes[j + 1] != info.shorthash: + same_model = False + + if same_model and (current["calcmode"][j] != mm_calcmodes[j] or current["mode"][j] != mm_modes[j]): + same_model = False + + if same_model: + # check modified block weighs + changed |= np.array(weights[j][:max_blocks]) != np.array(mm_weights[j][:max_blocks]) + else: + # check all non zero blocks + if len(weights) > j: + changed |= np.array(weights[j][:max_blocks]) != np.array([0.0]*max_blocks) + changed |= np.array(mm_weights[j][:max_blocks]) != np.array([0.0]*max_blocks) + BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL + print(" - partial changed blocks = ", [BLOCKIDS[k] for k, b in enumerate(changed) if b]) all_blocks = _all_blocks(isxl) weight_changed_blocks = [] for j, b in enumerate(changed): @@ -2573,8 +2590,6 @@ def add_difference(theta0, theta1, base, alpha): partial_update = True print(" - UNet partial update mode") - break - # check Rebasin mode if not isxl and "Rebasin" in calcmodes: print("Rebasin mode")