From 9de3aa9dd222c743f7dda9f59745cfbe7b004cf9 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 17 Jan 2024 12:14:00 +0900 Subject: [PATCH] fix para_to_weights() (issue #103) --- sd_modelmixer/hyper.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sd_modelmixer/hyper.py b/sd_modelmixer/hyper.py index 9be8c8b..0d90b15 100644 --- a/sd_modelmixer/hyper.py +++ b/sd_modelmixer/hyper.py @@ -22,14 +22,19 @@ classifiers = get_classifiers() -def para_to_weights(para, weights=None, isxl=False): +def para_to_weights(para, weights=None, alpha=None, isxl=False): BLOCKS = all_blocks(isxl) BLOCKLEN = (12 if not isxl else 9)*2 + 2 weights = {} if weights is None else dict(zip(range(len(weights)), weights)) + alpha = {} for k in para: name = k.split(".") modelidx = ord(name[0].split("_")[1]) - 98 + if name[1] == "alpha": + alpha[modelidx] = para[k] + continue + weight = weights.get(modelidx, [0.0]*BLOCKLEN) j = BLOCKS.index(name[1]) weight[j] = para[k] @@ -40,7 +45,12 @@ def para_to_weights(para, weights=None, isxl=False): for i in weights.keys(): nweights[i] = ",".join([("0" if float(f) == 0.0 else str(f)) for f in weights[i]]) - return nweights + maxid = max(alpha.keys()) + nalpha = [""] * (maxid + 1) + for i in alpha.keys(): + nalpha[i] = alpha[i] + + return nweights, nalpha def normalize_mbw(mbw, isxl): @@ -546,12 +556,13 @@ def hyper_score(localargs): shared.state.end() if best_para is not None: - best_weights = para_to_weights(best_para, weights, isxl) + best_weights, best_alpha = para_to_weights(best_para, weights, alpha, isxl) print(" - Best weights para = ", best_weights, override_uses) + print(" - Best alpha para = ", best_alpha) # setup override weights. will be replaced with mm_weights - shared.modelmixer_overrides = {"weights": best_weights, "uses": override_uses} + shared.modelmixer_overrides = {"weights": best_weights, "alpha": best_alpha, "uses": override_uses} # generate image with the optimized parameter ret = None