Skip to content

Commit

Permalink
update partial block merge method
Browse files Browse the repository at this point in the history
 * more flexible partial update. only the first model(model_a) need to be the same.
 * simplified modified blocks check method.
  • Loading branch information
wkpark committed Nov 16, 2023
1 parent 82ef002 commit d0979bc
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions scripts/model_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,33 +2506,50 @@ def add_difference(theta0, theta1, base, alpha):
if use_unet_partial_update and current is not None:
# check same models used
hashes = current["hashes"]
same_models = True
if len(hashes) != len(mm_models) + 1:
same_models = False
for j, m in enumerate([model_a] + [*mm_models]):
info = sd_models.get_closet_checkpoint_match(m)
if info is None:
same_models = False
break
# only the first model need to be checked
first_model_is_the_same = False

# check model_a
info = sd_models.get_closet_checkpoint_match(model_a)
if info is not None:
if info.shorthash is None:
info.calculate_shorthash()
if j >= len(hashes) or hashes[j] != info.shorthash:
same_models = False
break

if same_models:
print(" - check possible UNet partial update...")

if same_models and current["calcmode"] == mm_calcmodes and current["mode"] == mm_modes:
max_blocks = 26 - (0 if not isxl else 6)

# check changed weights
weights = current["weights"]
while len(weights) > 0:
changed = [False] * len(weights[0])
for j, w in enumerate(mm_weights):
changed |= np.array(weights[j][:max_blocks]) != np.array(w[:max_blocks])
if hashes[0] == info.shorthash:
first_model_is_the_same = True

if first_model_is_the_same:
print(" - check possible UNet partial update...")
max_blocks = 26 - (0 if not isxl else 6)

# check changed weights
weights = current["weights"]
changed = [False] * max_blocks

for j, m in enumerate(mm_models):
info = sd_models.get_closet_checkpoint_match(m)
if info is None:
if info.shorthash is None:
info.calculate_shorthash()

same_model = True
# is it different model?
if j + 1 >= len(hashes) or hashes[j + 1] != info.shorthash:
same_model = False

if same_model and (current["calcmode"][j] != mm_calcmodes[j] or current["mode"][j] != mm_modes[j]):
same_model = False

if same_model:
# check modified block weighs
changed |= np.array(weights[j][:max_blocks]) != np.array(mm_weights[j][:max_blocks])
else:
# check all non zero blocks
if len(weights) > j:
changed |= np.array(weights[j][:max_blocks]) != np.array([0.0]*max_blocks)
changed |= np.array(mm_weights[j][:max_blocks]) != np.array([0.0]*max_blocks)

BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL
print(" - partial changed blocks = ", [BLOCKIDS[k] for k, b in enumerate(changed) if b])
all_blocks = _all_blocks(isxl)
weight_changed_blocks = []
for j, b in enumerate(changed):
Expand Down Expand Up @@ -2573,8 +2590,6 @@ def add_difference(theta0, theta1, base, alpha):
partial_update = True
print(" - UNet partial update mode")

break

# check Rebasin mode
if not isxl and "Rebasin" in calcmodes:
print("Rebasin mode")
Expand Down

0 comments on commit d0979bc

Please sign in to comment.