From 06bf949e695ade45ab362cc276ff9e0d7af67d0d Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sat, 21 Jan 2023 14:26:37 -0500 Subject: [PATCH] black and ruff linting pass cleanup --- scripts/mbw/merge_block_weighted.py | 99 +- scripts/mbw/ui_mbw.py | 474 ++++++-- scripts/mbw_each/merge_block_weighted_mod.py | 83 +- scripts/mbw_each/ui_mbw_each.py | 1095 +++++++++++++++--- scripts/mbw_util/merge_history.py | 102 +- scripts/mbw_util/preset_weights.py | 11 +- scripts/merge_block_weighted_extension.py | 7 +- 7 files changed, 1458 insertions(+), 413 deletions(-) diff --git a/scripts/mbw/merge_block_weighted.py b/scripts/mbw/merge_block_weighted.py index 226ba58..9a33ac0 100644 --- a/scripts/mbw/merge_block_weighted.py +++ b/scripts/mbw/merge_block_weighted.py @@ -7,7 +7,8 @@ # bbc-mc import os -import argparse + +# import argparse import re import torch from tqdm import tqdm @@ -22,21 +23,30 @@ KEY_POSITION_IDS = "cond_stage_model.transformer.text_model.embeddings.position_ids" + def dprint(str, flg): if flg: print(str) -def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, - output_file="", allow_overwrite=False, verbose=False, - save_as_safetensors=False, - save_as_half=False, - skip_position_ids=0 - ): - if weights is None: - weights = None +def merge( + input_weights: str, + model_0, + model_1, + device="cpu", + base_alpha=0.5, + output_file="", + allow_overwrite=False, + verbose=False, + save_as_safetensors=False, + save_as_half=False, + skip_position_ids=0, +): + if input_weights is None: + weights = [] else: - weights = [float(w) for w in weights.split(',')] + weights = [float(w) for w in input_weights.split(",")] + if len(weights) != NUM_TOTAL_BLOCKS: _err_msg = f"weights value must be {NUM_TOTAL_BLOCKS}." print(_err_msg) @@ -46,8 +56,10 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, def load_model(_model, _device="cpu"): model_info = sd_models.get_closet_checkpoint_match(_model) - if model_info: - model_file = model_info.filename + if model_info is None: + raise RuntimeError("invalid model filename") + + model_file = model_info.filename return sd_models.read_state_dict(model_file, map_location=_device) print("loading", model_0) @@ -69,42 +81,48 @@ def load_model(_model, _device="cpu"): print(_err_msg) return False, _err_msg - re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 - re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 - re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 + re_inp = re.compile(r"\.input_blocks\.(\d+)\.") # 12 + re_mid = re.compile(r"\.middle_block\.(\d+)\.") # 1 + re_out = re.compile(r"\.output_blocks\.(\d+)\.") # 12 print(" merging ...") - dprint(f"-- start Stage 1/2 --", verbose) + dprint("-- start Stage 1/2 --", verbose) count_target_of_basealpha = 0 - for key in (tqdm(theta_0.keys(), desc="Stage 1/2")): + for key in tqdm(theta_0.keys(), desc="Stage 1/2"): if "model" in key and key in theta_1: if KEY_POSITION_IDS in key: print(key) if skip_position_ids == 1: - print(f" modelA: skip 'position_ids' : dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: skip 'position_ids' : dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue elif skip_position_ids == 2: theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) - print(f" modelA: reset 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: reset 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue else: - print(f" modelA: 'position_ids' key found. do nothing : {skip_position_ids}: dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: 'position_ids' key found. do nothing : {skip_position_ids}: dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f" key : {key}", verbose) current_alpha = alpha # check weighted and U-Net or not - if weights is not None and 'model.diffusion_model.' in key: + if weights is not None and "model.diffusion_model." in key: # check block index weight_index = -1 - if 'time_embed' in key: - weight_index = 0 # before input blocks - elif '.out.' in key: - weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks + if "time_embed" in key: + weight_index = 0 # before input blocks + elif ".out." in key: + weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks else: m = re_inp.search(key) if m: @@ -118,11 +136,14 @@ def load_model(_model, _device="cpu"): m = re_out.search(key) if m: out_idx = int(m.groups()[0]) - weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx + weight_index = ( + NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx + ) if weight_index >= NUM_TOTAL_BLOCKS: print(f"error. illegal block index: {key}") return False, "" + if weight_index >= 0: current_alpha = weights[weight_index] dprint(f"weighted '{key}': {current_alpha}", verbose) @@ -130,7 +151,9 @@ def load_model(_model, _device="cpu"): count_target_of_basealpha = count_target_of_basealpha + 1 dprint(f"base_alpha applied: [{key}]", verbose) - theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] + theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[ + key + ] if save_as_half: theta_0[key] = theta_0[key].half() @@ -138,25 +161,31 @@ def load_model(_model, _device="cpu"): else: dprint(f" key - {key}", verbose) - dprint(f"-- start Stage 2/2 --", verbose) + dprint("-- start Stage 2/2 --", verbose) for key in tqdm(theta_1.keys(), desc="Stage 2/2"): if "model" in key and key not in theta_0: if KEY_POSITION_IDS in key: if skip_position_ids == 1: - print(f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue elif skip_position_ids == 2: theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) - print(f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue else: - print(f" modelB: 'position_ids' key found. do nothing : {skip_position_ids}") + print( + f" modelB: 'position_ids' key found. do nothing : {skip_position_ids}" + ) dprint(f" key : {key}", verbose) - theta_0.update({key:theta_1[key]}) + theta_0.update({key: theta_1[key]}) if save_as_half: theta_0[key] = theta_0[key].half() @@ -171,10 +200,14 @@ def load_model(_model, _device="cpu"): if save_as_safetensors and extension.lower() != ".safetensors": output_file = output_file + ".safetensors" import safetensors.torch + safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"}) else: torch.save({"state_dict": theta_0}, output_file) print("Done!") - return True, f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times." + return ( + True, + f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times.", + ) diff --git a/scripts/mbw/ui_mbw.py b/scripts/mbw/ui_mbw.py index ac71566..3b3b7a0 100644 --- a/scripts/mbw/ui_mbw.py +++ b/scripts/mbw/ui_mbw.py @@ -1,18 +1,19 @@ -import gradio as gr import os import re +import gradio as gr from modules import sd_models, shared -from tqdm import tqdm + try: from modules import hashes from modules.sd_models import CheckpointInfo except: + print("could not import hashes or CheckpointInfo from SDwebui") pass from scripts.mbw.merge_block_weighted import merge -from scripts.mbw_util.preset_weights import PresetWeights from scripts.mbw_util.merge_history import MergeHistory +from scripts.mbw_util.preset_weights import PresetWeights presetWeights = PresetWeights() mergeHistory = MergeHistory() @@ -24,42 +25,90 @@ def on_ui_tabs(): with gr.Column(variant="panel"): html_output_block_weight_info = gr.HTML() with gr.Row(): - btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") + btn_do_merge_block_weighted = gr.Button( + value="Run Merge", variant="primary" + ) btn_clear_weight = gr.Button(value="Clear values") btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") with gr.Column(): - dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list()) - txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25") - btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") + dd_preset_weight = gr.Dropdown( + label="Preset Weights", choices=presetWeights.get_preset_name_list() + ) + txt_block_weight = gr.Text( + label="Weight values", + placeholder="Put weight sets. float number x 25", + ) + btn_apply_block_weithg_from_txt = gr.Button( + value="Apply block weight from text", variant="primary" + ) with gr.Row(): - sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0) - chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) - chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) + sl_base_alpha = gr.Slider( + label="base_alpha", minimum=0, maximum=1, step=0.01, value=0 + ) + chk_verbose_mbw = gr.Checkbox( + label="verbose console output", value=False + ) + chk_allow_overwrite = gr.Checkbox( + label="Allow overwrite output-model", value=False + ) with gr.Row(): with gr.Column(scale=3): with gr.Row(): - chk_save_as_half = gr.Checkbox(label="Save as half", value=False) - chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False) + chk_save_as_half = gr.Checkbox( + label="Save as half", value=False + ) + chk_save_as_safetensors = gr.Checkbox( + label="Save as safetensors", value=False + ) with gr.Column(scale=4): - radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids", choices=["None", "Skip", "Force Reset"], value="None", type="index") + radio_position_ids = gr.Radio( + label="Skip/Reset CLIP position_ids", + choices=["None", "Skip", "Force Reset"], + value="None", + type="index", + ) with gr.Row(): model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles()) model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles()) txt_model_O = gr.Text(label="Output Model Name") with gr.Row(): with gr.Column(): - sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5) - sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5) + sl_IN_00 = gr.Slider( + label="IN00", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_01 = gr.Slider( + label="IN01", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_02 = gr.Slider( + label="IN02", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_03 = gr.Slider( + label="IN03", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_04 = gr.Slider( + label="IN04", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_05 = gr.Slider( + label="IN05", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_06 = gr.Slider( + label="IN06", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_07 = gr.Slider( + label="IN07", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_08 = gr.Slider( + label="IN08", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_09 = gr.Slider( + label="IN09", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_10 = gr.Slider( + label="IN10", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_IN_11 = gr.Slider( + label="IN11", minimum=0, maximum=1, step=0.01, value=0.5 + ) with gr.Column(): gr.Slider(visible=False) gr.Slider(visible=False) @@ -72,53 +121,155 @@ def on_ui_tabs(): gr.Slider(visible=False) gr.Slider(visible=False) gr.Slider(visible=False) - sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="mbw_sl_M00") + sl_M_00 = gr.Slider( + label="M00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="mbw_sl_M00", + ) with gr.Column(): - sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5) - sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5) + sl_OUT_11 = gr.Slider( + label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_10 = gr.Slider( + label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_09 = gr.Slider( + label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_08 = gr.Slider( + label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_07 = gr.Slider( + label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_06 = gr.Slider( + label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_05 = gr.Slider( + label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_04 = gr.Slider( + label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_03 = gr.Slider( + label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_02 = gr.Slider( + label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_01 = gr.Slider( + label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5 + ) + sl_OUT_00 = gr.Slider( + label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5 + ) sl_IN = [ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11] + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, + ] sl_MID = [sl_M_00] sl_OUT = [ - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11] + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + ] # Events def onclick_btn_do_merge_block_weighted( - model_A, model_B, - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + model_A, + model_B, + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite, - chk_save_as_safetensors, chk_save_as_half, - radio_position_ids + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + txt_model_O, + sl_base_alpha, + chk_verbose_mbw, + chk_allow_overwrite, + chk_save_as_safetensors, + chk_save_as_half, + radio_position_ids, ): # debug output - print( "#### Merge Block Weighted ####") + print("#### Merge Block Weighted ####") _weights = ",".join( - [str(x) for x in [ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, - sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11 - ]]) + [ + str(x) + for x in [ + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, + sl_M_00, + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + ] + ] + ) # if not model_A or not model_B: return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]") @@ -139,8 +290,10 @@ def onclick_btn_do_merge_block_weighted( else: _model_B_info = "" - def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False): - output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename) + def validate_output_filename( + output_filename, save_as_safetensors=False, save_as_half=False + ): + output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', "-", output_filename) filename_body, filename_ext = os.path.splitext(output_filename) _ret = output_filename _footer = "-half" if save_as_half else "" @@ -152,14 +305,26 @@ def validate_output_filename(output_filename, save_as_safetensors=False, save_as _ret = f"{output_filename}{_footer}.ckpt" return _ret - model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O - model_O = validate_output_filename(model_O, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half) + model_O = ( + f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" + if txt_model_O == "" + else txt_model_O + ) + model_O = validate_output_filename( + model_O, + save_as_safetensors=chk_save_as_safetensors, + save_as_half=chk_save_as_half, + ) - _output = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) + _output = os.path.join( + shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O + ) if not chk_allow_overwrite: if os.path.exists(_output): - _err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort." + _err_msg = ( + "ERROR: output_file already exists. overwrite not allowed. abort." + ) print(_err_msg) return gr.update(value=f"{_err_msg} [{_output}]") print(f" model_0 : {model_A}") @@ -169,20 +334,28 @@ def validate_output_filename(output_filename, save_as_safetensors=False, save_as print(f" weights : {_weights}") print(f" skip ids : {radio_position_ids} : 0:None, 1:Skip, 2:Reset") - result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, - base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw, + result, ret_message = merge( + input_weights=_weights, + model_0=model_A, + model_1=model_B, + allow_overwrite=chk_allow_overwrite, + base_alpha=sl_base_alpha, + output_file=_output, + verbose=chk_verbose_mbw, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half, - skip_position_ids=radio_position_ids - ) + skip_position_ids=radio_position_ids, + ) if result: - ret_html = "merged.
" \ - + f"{model_A}
" \ - + f"{model_B}
" \ - + f"{model_O}
" \ - + f"base_alpha={sl_base_alpha}
" \ + ret_html = ( + "merged.
" + + f"{model_A}
" + + f"{model_B}
" + + f"{model_O}
" + + f"base_alpha={sl_base_alpha}
" + f"Weight_values={_weights}
" + ) print("merged.") else: ret_html = ret_message @@ -194,8 +367,10 @@ def validate_output_filename(output_filename, save_as_safetensors=False, save_as model_B_info = sd_models.get_closet_checkpoint_match(model_B) model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(_output)) if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: - model_O_info:CheckpointInfo = model_O_info - model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) + model_O_info: CheckpointInfo = model_O_info + model_O_info.sha256 = hashes.sha256( + model_O_info.filename, "checkpoint/" + model_O_info.title + ) _names = presetWeights.find_names_by_weight(_weights) if _names and len(_names) > 0: weight_name = _names[0] @@ -204,87 +379,156 @@ def validate_output_filename(output_filename, save_as_safetensors=False, save_as def model_name(model_info): return model_info.name if hasattr(model_info, "name") else model_info.title + def model_sha256(model_info): return model_info.sha256 if hasattr(model_info, "sha256") else "" + mergeHistory.add_history( - model_name(model_A_info), - model_A_info.hash, - model_sha256(model_A_info), - model_name(model_B_info), - model_B_info.hash, - model_sha256(model_B_info), - model_name(model_O_info), - model_O_info.hash, - model_sha256(model_O_info), - sl_base_alpha, - _weights, - "", - weight_name - ) + model_name(model_A_info), + model_A_info.hash, + model_sha256(model_A_info), + model_name(model_B_info), + model_B_info.hash, + model_sha256(model_B_info), + model_name(model_O_info), + model_O_info.hash, + model_sha256(model_O_info), + sl_base_alpha, + _weights, + "", + weight_name, + ) return gr.update(value=f"{ret_html}") + btn_do_merge_block_weighted.click( fn=onclick_btn_do_merge_block_weighted, inputs=[model_A, model_B] - + sl_IN + sl_MID + sl_OUT - + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] - + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], - outputs=[html_output_block_weight_info] + + sl_IN + + sl_MID + + sl_OUT + + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] + + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], + outputs=[html_output_block_weight_info], ) btn_clear_weight.click( fn=lambda: [gr.update(value=0.5) for _ in range(25)], inputs=[], outputs=[ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + ], ) def on_change_dd_preset_weight(dd_preset_weight): _weights = presetWeights.find_weight_by_name(dd_preset_weight) _ret = on_btn_apply_block_weight_from_txt(_weights) return [gr.update(value=_weights)] + _ret + dd_preset_weight.change( fn=on_change_dd_preset_weight, inputs=[dd_preset_weight], - outputs=[txt_block_weight, - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + outputs=[ + txt_block_weight, + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + ], ) def on_btn_reload_checkpoint_mbw(): sd_models.list_models() - return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] + return [ + gr.update(choices=sd_models.checkpoint_tiles()), + gr.update(choices=sd_models.checkpoint_tiles()), + ] + btn_reload_checkpoint_mbw.click( - fn=on_btn_reload_checkpoint_mbw, - inputs=[], - outputs=[model_A, model_B] + fn=on_btn_reload_checkpoint_mbw, inputs=[], outputs=[model_A, model_B] ) def on_btn_apply_block_weight_from_txt(txt_block_weight): if not txt_block_weight or txt_block_weight == "": return [gr.update() for _ in range(25)] _list = [x.strip() for x in txt_block_weight.split(",")] - if(len(_list) != 25): + if len(_list) != 25: return [gr.update() for _ in range(25)] return [gr.update(value=x) for x in _list] + btn_apply_block_weithg_from_txt.click( fn=on_btn_apply_block_weight_from_txt, inputs=[txt_block_weight], outputs=[ - sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, - sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, + sl_IN_00, + sl_IN_01, + sl_IN_02, + sl_IN_03, + sl_IN_04, + sl_IN_05, + sl_IN_06, + sl_IN_07, + sl_IN_08, + sl_IN_09, + sl_IN_10, + sl_IN_11, sl_M_00, - sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, - sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, - ] + sl_OUT_00, + sl_OUT_01, + sl_OUT_02, + sl_OUT_03, + sl_OUT_04, + sl_OUT_05, + sl_OUT_06, + sl_OUT_07, + sl_OUT_08, + sl_OUT_09, + sl_OUT_10, + sl_OUT_11, + ], ) - diff --git a/scripts/mbw_each/merge_block_weighted_mod.py b/scripts/mbw_each/merge_block_weighted_mod.py index d7f6e09..ba217a9 100644 --- a/scripts/mbw_each/merge_block_weighted_mod.py +++ b/scripts/mbw_each/merge_block_weighted_mod.py @@ -7,7 +7,6 @@ # bbc-mc import os -import argparse import re import torch from tqdm import tqdm @@ -22,18 +21,26 @@ KEY_POSITION_IDS = "cond_stage_model.transformer.text_model.embeddings.position_ids" + def dprint(str, flg): if flg: print(str) -def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5, - output_file="", allow_overwrite=False, verbose=False, - save_as_safetensors=False, - save_as_half=False, - skip_position_ids=0, - ): - +def merge( + weight_A: list, + weight_B: list, + model_0, + model_1, + device="cpu", + base_alpha=0.5, + output_file="", + allow_overwrite=False, + verbose=False, + save_as_safetensors=False, + save_as_half=False, + skip_position_ids=0, +): def _check_arg_weight(weight): if weight is None: return None @@ -90,27 +97,35 @@ def load_model(_model, _device): print("loading", model_1) theta_1 = load_model(model_1, device) - re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 - re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 - re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 + re_inp = re.compile(r"\.input_blocks\.(\d+)\.") # 12 + re_mid = re.compile(r"\.middle_block\.(\d+)\.") # 1 + re_out = re.compile(r"\.output_blocks\.(\d+)\.") # 12 - dprint(f"-- start Stage 1/2 --", verbose) + dprint("-- start Stage 1/2 --", verbose) count_target_of_basealpha = 0 - for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()): + for key in ( + tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys() + ): if "model" in key and key in theta_1: if KEY_POSITION_IDS in key: if skip_position_ids == 1: - print(f" modelA: skip 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: skip 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue elif skip_position_ids == 2: theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) - print(f" modelA: reset 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: reset 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue else: - print(f" modelA: key found. do nothing: dtype:{theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelA: key found. do nothing: dtype:{theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f" key : {key}", verbose) @@ -119,14 +134,14 @@ def load_model(_model, _device): current_alpha_I = 0 # check weighted and U-Net or not - if weight_A is not None and 'model.diffusion_model.' in key: + if weight_A is not None and "model.diffusion_model." in key: # check block index weight_index = -1 - if 'time_embed' in key: - weight_index = 0 # before input blocks - elif '.out.' in key: - weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks + if "time_embed" in key: + weight_index = 0 # before input blocks + elif ".out." in key: + weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks else: m = re_inp.search(key) if m: @@ -140,7 +155,9 @@ def load_model(_model, _device): m = re_out.search(key) if m: out_idx = int(m.groups()[0]) - weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx + weight_index = ( + NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx + ) if weight_index >= NUM_TOTAL_BLOCKS: print(f"error. illegal block index: {key}") @@ -149,7 +166,9 @@ def load_model(_model, _device): current_alpha_B = weight_B[weight_index] current_alpha_I = 1 - current_alpha_A - current_alpha_B if verbose: - print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}") + print( + f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}" + ) # create I tensor tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype) @@ -164,25 +183,29 @@ def load_model(_model, _device): else: dprint(f" key - {key}", verbose) - dprint(f"-- start Stage 2/2 --", verbose) + dprint("-- start Stage 2/2 --", verbose) for key in tqdm(theta_1.keys(), desc="Stage 2/2"): if "model" in key and key not in theta_0: if KEY_POSITION_IDS in key: if skip_position_ids == 1: - print(f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue elif skip_position_ids == 2: theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) - print(f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}") + print( + f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}" + ) dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) continue else: print(f" modelB: key found. do nothing : {skip_position_ids}") dprint(f" key : {key}", verbose) - theta_0.update({key:theta_1[key]}) + theta_0.update({key: theta_1[key]}) if save_as_half: theta_0[key] = theta_0[key].half() @@ -197,10 +220,14 @@ def load_model(_model, _device): if save_as_safetensors and extension.lower() != ".safetensors": output_file = output_file + ".safetensors" import safetensors.torch + safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"}) else: torch.save({"state_dict": theta_0}, output_file) print("Done!") - return True, f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times." + return ( + True, + f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times.", + ) diff --git a/scripts/mbw_each/ui_mbw_each.py b/scripts/mbw_each/ui_mbw_each.py index 2a3f366..0bbe5a4 100644 --- a/scripts/mbw_each/ui_mbw_each.py +++ b/scripts/mbw_each/ui_mbw_each.py @@ -3,11 +3,12 @@ import re from modules import sd_models, shared -from tqdm import tqdm + try: from modules import hashes from modules.sd_models import CheckpointInfo except: + print("could not import hashes or CheckpointInfo from SDwebui") pass from scripts.mbw_each.merge_block_weighted_mod import merge @@ -23,58 +24,256 @@ def on_ui_tabs(): with gr.Row(): with gr.Column(variant="panel"): with gr.Row(): - txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.") + txt_multi_process_cmd = gr.TextArea( + label="Multi Proc Cmd", placeholder="Keep empty if dont use." + ) html_output_block_weight_info = gr.HTML() with gr.Row(): - btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") + btn_do_merge_block_weighted = gr.Button( + value="Run Merge", variant="primary" + ) btn_clear_weighted = gr.Button(value="Clear values") btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") with gr.Column(): - dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list()) - txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25") - btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") + dd_preset_weight = gr.Dropdown( + label="Preset_Weights", choices=presetWeights.get_preset_name_list() + ) + txt_block_weight = gr.Text( + label="Weight_values", + placeholder="Put weight sets. float number x 25", + ) + btn_apply_block_weithg_from_txt = gr.Button( + value="Apply block weight from text", variant="primary" + ) with gr.Row(): - sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0) - chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) - chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) + sl_base_alpha = gr.Slider( + label="base_alpha", minimum=0, maximum=1, step=0.01, value=0 + ) + chk_verbose_mbw = gr.Checkbox( + label="verbose console output", value=False + ) + chk_allow_overwrite = gr.Checkbox( + label="Allow overwrite output-model", value=False + ) with gr.Row(): with gr.Column(scale=3): with gr.Row(): - chk_save_as_half = gr.Checkbox(label="Save as half", value=False) - chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False) + chk_save_as_half = gr.Checkbox( + label="Save as half", value=False + ) + chk_save_as_safetensors = gr.Checkbox( + label="Save as safetensors", value=False + ) with gr.Column(scale=4): - radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids", choices=["None", "Skip", "Force Reset"], value="None", type="index") + radio_position_ids = gr.Radio( + label="Skip/Reset CLIP position_ids", + choices=["None", "Skip", "Force Reset"], + value="None", + type="index", + ) with gr.Row(): - dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles()) - dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles()) + dd_model_A = gr.Dropdown( + label="Model_A", choices=sd_models.checkpoint_tiles() + ) + dd_model_B = gr.Dropdown( + label="Model_B", choices=sd_models.checkpoint_tiles() + ) txt_model_O = gr.Text(label="(O)Output Model Name") with gr.Row(): with gr.Column(): - sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_00") - sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_01") - sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_02") - sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_03") - sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_04") - sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_05") - sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_06") - sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_07") - sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_08") - sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_09") - sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_10") - sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_11") + sl_IN_A_00 = gr.Slider( + label="IN_A_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_00", + ) + sl_IN_A_01 = gr.Slider( + label="IN_A_01", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_01", + ) + sl_IN_A_02 = gr.Slider( + label="IN_A_02", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_02", + ) + sl_IN_A_03 = gr.Slider( + label="IN_A_03", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_03", + ) + sl_IN_A_04 = gr.Slider( + label="IN_A_04", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_04", + ) + sl_IN_A_05 = gr.Slider( + label="IN_A_05", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_05", + ) + sl_IN_A_06 = gr.Slider( + label="IN_A_06", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_06", + ) + sl_IN_A_07 = gr.Slider( + label="IN_A_07", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_07", + ) + sl_IN_A_08 = gr.Slider( + label="IN_A_08", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_08", + ) + sl_IN_A_09 = gr.Slider( + label="IN_A_09", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_09", + ) + sl_IN_A_10 = gr.Slider( + label="IN_A_10", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_10", + ) + sl_IN_A_11 = gr.Slider( + label="IN_A_11", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_A_11", + ) with gr.Column(): - sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_00") - sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_01") - sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_02") - sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_03") - sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_04") - sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_05") - sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_06") - sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_07") - sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_08") - sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_09") - sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_10") - sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_11") + sl_IN_B_00 = gr.Slider( + label="IN_B_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_00", + ) + sl_IN_B_01 = gr.Slider( + label="IN_B_01", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_01", + ) + sl_IN_B_02 = gr.Slider( + label="IN_B_02", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_02", + ) + sl_IN_B_03 = gr.Slider( + label="IN_B_03", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_03", + ) + sl_IN_B_04 = gr.Slider( + label="IN_B_04", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_04", + ) + sl_IN_B_05 = gr.Slider( + label="IN_B_05", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_05", + ) + sl_IN_B_06 = gr.Slider( + label="IN_B_06", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_06", + ) + sl_IN_B_07 = gr.Slider( + label="IN_B_07", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_07", + ) + sl_IN_B_08 = gr.Slider( + label="IN_B_08", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_08", + ) + sl_IN_B_09 = gr.Slider( + label="IN_B_09", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_09", + ) + sl_IN_B_10 = gr.Slider( + label="IN_B_10", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_10", + ) + sl_IN_B_11 = gr.Slider( + label="IN_B_11", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_IN_B_11", + ) with gr.Column(): gr.Slider(visible=False) gr.Slider(visible=False) @@ -87,7 +286,14 @@ def on_ui_tabs(): gr.Slider(visible=False) gr.Slider(visible=False) gr.Slider(visible=False) - sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_M_A_00") + sl_M_A_00 = gr.Slider( + label="M_A_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_M_A_00", + ) with gr.Column(): gr.Slider(visible=False) gr.Slider(visible=False) @@ -100,33 +306,208 @@ def on_ui_tabs(): gr.Slider(visible=False) gr.Slider(visible=False) gr.Slider(visible=False) - sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_M_B_00") + sl_M_B_00 = gr.Slider( + label="M_B_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_M_B_00", + ) with gr.Column(): - sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_11") - sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_10") - sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_09") - sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_08") - sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_07") - sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_06") - sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_05") - sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_04") - sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_03") - sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_02") - sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_01") - sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_00") + sl_OUT_A_11 = gr.Slider( + label="OUT_A_11", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_11", + ) + sl_OUT_A_10 = gr.Slider( + label="OUT_A_10", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_10", + ) + sl_OUT_A_09 = gr.Slider( + label="OUT_A_09", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_09", + ) + sl_OUT_A_08 = gr.Slider( + label="OUT_A_08", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_08", + ) + sl_OUT_A_07 = gr.Slider( + label="OUT_A_07", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_07", + ) + sl_OUT_A_06 = gr.Slider( + label="OUT_A_06", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_06", + ) + sl_OUT_A_05 = gr.Slider( + label="OUT_A_05", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_05", + ) + sl_OUT_A_04 = gr.Slider( + label="OUT_A_04", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_04", + ) + sl_OUT_A_03 = gr.Slider( + label="OUT_A_03", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_03", + ) + sl_OUT_A_02 = gr.Slider( + label="OUT_A_02", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_02", + ) + sl_OUT_A_01 = gr.Slider( + label="OUT_A_01", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_01", + ) + sl_OUT_A_00 = gr.Slider( + label="OUT_A_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_A_00", + ) with gr.Column(): - sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_11") - sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_10") - sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_09") - sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_08") - sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_07") - sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_06") - sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_05") - sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_04") - sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_03") - sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_02") - sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_01") - sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_00") + sl_OUT_B_11 = gr.Slider( + label="OUT_B_11", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_11", + ) + sl_OUT_B_10 = gr.Slider( + label="OUT_B_10", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_10", + ) + sl_OUT_B_09 = gr.Slider( + label="OUT_B_09", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_09", + ) + sl_OUT_B_08 = gr.Slider( + label="OUT_B_08", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_08", + ) + sl_OUT_B_07 = gr.Slider( + label="OUT_B_07", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_07", + ) + sl_OUT_B_06 = gr.Slider( + label="OUT_B_06", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_06", + ) + sl_OUT_B_05 = gr.Slider( + label="OUT_B_05", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_05", + ) + sl_OUT_B_04 = gr.Slider( + label="OUT_B_04", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_04", + ) + sl_OUT_B_03 = gr.Slider( + label="OUT_B_03", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_03", + ) + sl_OUT_B_02 = gr.Slider( + label="OUT_B_02", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_02", + ) + sl_OUT_B_01 = gr.Slider( + label="OUT_B_01", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_01", + ) + sl_OUT_B_00 = gr.Slider( + label="OUT_B_00", + minimum=0, + maximum=1, + step=0.01, + value=0.5, + elem_id="sl_OUT_B_00", + ) # Footer gr.HTML( @@ -140,59 +521,196 @@ def on_ui_tabs(): ) sl_A_IN = [ - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11] + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, + ] sl_A_MID = [sl_M_A_00] sl_A_OUT = [ - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11] + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + ] sl_B_IN = [ - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11] + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, + ] sl_B_MID = [sl_M_B_00] sl_B_OUT = [ - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11] - + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + ] # Events def onclick_btn_do_merge_block_weighted( - dd_model_A, dd_model_B, txt_multi_process_cmd, - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + dd_model_A, + dd_model_B, + txt_multi_process_cmd, + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, sl_M_A_00, - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, sl_M_B_00, - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, - txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite, - chk_save_as_safetensors, chk_save_as_half, - radio_position_ids + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + txt_model_O, + sl_base_alpha, + chk_verbose_mbw, + chk_allow_overwrite, + chk_save_as_safetensors, + chk_save_as_half, + radio_position_ids, ): base_alpha = sl_base_alpha _weight_A = ",".join( - [str(x) for x in [ - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, - sl_M_A_00, - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, - ]]) + [ + str(x) + for x in [ + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, + sl_M_A_00, + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + ] + ] + ) _weight_B = ",".join( - [str(x) for x in [ - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, - sl_M_B_00, - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, - ]]) + [ + str(x) + for x in [ + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, + sl_M_B_00, + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + ] + ] + ) # debug output - print( "#### Merge Block Weighted : Each ####") + print("#### Merge Block Weighted : Each ####") if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "": _err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]" @@ -202,7 +720,7 @@ def onclick_btn_do_merge_block_weighted( ret_html = "" if txt_multi_process_cmd != "": # need multi-merge - _lines = txt_multi_process_cmd.split('\n') + _lines = txt_multi_process_cmd.split("\n") print(f"check multi-merge. {len(_lines)} lines found.") for line_index, _line in enumerate(_lines): if _line == "": @@ -211,14 +729,19 @@ def onclick_btn_do_merge_block_weighted( _items = [x.strip() for x in _line.split(",") if x != ""] if len(_items) > 0: ret_html += _run_merge( - weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, - allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, + weight_A=_weight_A, + weight_B=_weight_B, + model_0=dd_model_A, + model_1=dd_model_B, + allow_overwrite=chk_allow_overwrite, + base_alpha=base_alpha, + model_Output=txt_model_O, verbose=chk_verbose_mbw, params=_items, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half, - skip_position_ids=radio_position_ids - ) + skip_position_ids=radio_position_ids, + ) else: _ret = f" multi-merge text found, but invalid params. skipped :[{_line}]" ret_html += _ret @@ -226,36 +749,56 @@ def onclick_btn_do_merge_block_weighted( else: # normal merge ret_html += _run_merge( - weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, - allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, + weight_A=_weight_A, + weight_B=_weight_B, + model_0=dd_model_A, + model_1=dd_model_B, + allow_overwrite=chk_allow_overwrite, + base_alpha=base_alpha, + model_Output=txt_model_O, verbose=chk_verbose_mbw, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half, - skip_position_ids=radio_position_ids - ) + skip_position_ids=radio_position_ids, + ) sd_models.list_models() - print( "#### All merge process done. ####") + print("#### All merge process done. ####") return gr.update(value=f"{ret_html}") + btn_do_merge_block_weighted.click( fn=onclick_btn_do_merge_block_weighted, inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] - + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT - + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] - + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], - outputs=[html_output_block_weight_info] + + sl_A_IN + + sl_A_MID + + sl_A_OUT + + sl_B_IN + + sl_B_MID + + sl_B_OUT + + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] + + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], + outputs=[html_output_block_weight_info], ) - def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, - model_Output="", verbose=False, params=[], + def _run_merge( + weight_A, + weight_B, + model_0, + model_1, + allow_overwrite=False, + base_alpha=0, + model_Output="", + verbose=False, + params=[], save_as_safetensors=False, save_as_half=False, skip_position_ids=0, + ): + def validate_output_filename( + output_filename, save_as_safetensors=False, save_as_half=False ): - - def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False): - output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename) + output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', "-", output_filename) filename_body, filename_ext = os.path.splitext(output_filename) _ret = output_filename _footer = "-half" if save_as_half else "" @@ -282,39 +825,61 @@ def validate_output_filename(output_filename, save_as_safetensors=False, save_as _model_name = _model_info.title.split(" ")[0] if _model_name and _model_name.strip() != "": if _item_l.lower() == "model_a": - print(f" * Model changed: {model_0} -> {_model_info.title}") + print( + f" * Model changed: {model_0} -> {_model_info.title}" + ) model_0 = _model_info.title elif _item_l.lower() == "model_b": - print(f" * Model changed: {model_1} -> {_model_info.title}") + print( + f" * Model changed: {model_1} -> {_model_info.title}" + ) model_1 = _model_info.title elif _item_l.lower() == "preset_weights": _weights = presetWeights.find_weight_by_name(_item_r) - if _weights != "" and len(_weights.split(',')) == 25: + if _weights != "" and len(_weights.split(",")) == 25: print(f" * Weights changed by preset-name: {_item_r}") weight_B = _weights - weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')]) + weight_A = ",".join( + [str(1 - float(x)) for x in _weights.split(",")] + ) else: - print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") + print( + f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]" + ) elif _item_l.lower() == "weight_values": _weights = _item_r.strip() - if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator. + if ( + _weights != "" and len(_weights.split(" ")) == 25 + ): # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator. print(f" * Weights changed: {_item_r}") weight_B = _weights - weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')]) + weight_A = ",".join( + [str(1 - float(x)) for x in _weights.split(" ")] + ) else: - print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") + print( + f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]" + ) elif _item_l.lower() == "base_alpha": if float(_item_r) >= 0: - print(f" * base_alpha changed: {base_alpha} -> {_item_r}") + print( + f" * base_alpha changed: {base_alpha} -> {_item_r}" + ) base_alpha = float(_item_r) elif _item_l.upper() == "O": if _item_r.strip() != "": - _ret = validate_output_filename(_item_r.strip(), save_as_safetensors=save_as_safetensors, save_as_half=save_as_half) - print(f" * Output filename changed:[{model_O}] -> [{_ret}]") + _ret = validate_output_filename( + _item_r.strip(), + save_as_safetensors=save_as_safetensors, + save_as_half=save_as_half, + ) + print( + f" * Output filename changed:[{model_O}] -> [{_ret}]" + ) model_O = _ret elif len(_item_l.split("_")) == 3: @@ -341,7 +906,9 @@ def _apply_val(key, weight, index, new_value): elif _AB == "B": weight_B = _apply_val(_AB, weight_B, _index, _item_r) else: - print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]") + print( + f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]" + ) # # Prepare params before run merge @@ -357,9 +924,17 @@ def _apply_val(key, weight, index, new_value): if model_O == "": _a = os.path.splitext(os.path.basename(_model_A_name))[0] _b = os.path.splitext(os.path.basename(_model_B_name))[0] - model_O = f"bw-merge-{_a}-{_b}-{base_alpha}" if model_Output == "" else model_Output - model_O = validate_output_filename(model_O, save_as_safetensors=save_as_safetensors, save_as_half=save_as_half) - output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) + model_O = ( + f"bw-merge-{_a}-{_b}-{base_alpha}" + if model_Output == "" + else model_Output + ) + model_O = validate_output_filename( + model_O, save_as_safetensors=save_as_safetensors, save_as_half=save_as_half + ) + output_file = os.path.join( + shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O + ) # # Check params # @@ -369,7 +944,7 @@ def _apply_val(key, weight, index, new_value): return _err_msg + "
" if not allow_overwrite: if os.path.exists(output_file): - _err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped." + _err_msg = "WARNING: output_file already exists. overwrite not allowed. skipped." print(_err_msg) return _err_msg + "
" @@ -385,13 +960,18 @@ def _apply_val(key, weight, index, new_value): print(f" skip ids : {skip_position_ids} : 0:None, 1:Skip, 2:Reset") result, ret_message = merge( - weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1, - allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, + weight_A=weight_A, + weight_B=weight_B, + model_0=model_0, + model_1=model_1, + allow_overwrite=allow_overwrite, + base_alpha=base_alpha, + output_file=output_file, verbose=verbose, save_as_safetensors=save_as_safetensors, save_as_half=save_as_half, skip_position_ids=skip_position_ids, - ) + ) if result: ret_html = f"merged. {model_0} + {model_1} = {model_O}
" print("merged.") @@ -399,15 +979,18 @@ def _apply_val(key, weight, index, new_value): ret_html = ret_message print("merge failed.") - # save log to history.tsv sd_models.list_models() model_A_info = sd_models.get_closet_checkpoint_match(model_0) model_B_info = sd_models.get_closet_checkpoint_match(model_1) - model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(output_file)) + model_O_info = sd_models.get_closet_checkpoint_match( + os.path.basename(output_file) + ) if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: - model_O_info:CheckpointInfo = model_O_info - model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) + model_O_info: CheckpointInfo = model_O_info + model_O_info.sha256 = hashes.sha256( + model_O_info.filename, "checkpoint/" + model_O_info.title + ) _names = presetWeights.find_names_by_weight(weight_B) if _names and len(_names) > 0: weight_name = _names[0] @@ -416,94 +999,222 @@ def _apply_val(key, weight, index, new_value): def model_name(model_info): return model_info.name if hasattr(model_info, "name") else model_info.title + def model_sha256(model_info): return model_info.sha256 if hasattr(model_info, "sha256") else "" + mergeHistory.add_history( - model_name(model_A_info), - model_A_info.hash, - model_sha256(model_A_info), - model_name(model_B_info), - model_B_info.hash, - model_sha256(model_B_info), - model_name(model_O_info), - model_O_info.hash, - model_sha256(model_O_info), - base_alpha, - weight_A, - weight_B, - weight_name - ) + model_name(model_A_info), + model_A_info.hash, + model_sha256(model_A_info), + model_name(model_B_info), + model_B_info.hash, + model_sha256(model_B_info), + model_name(model_O_info), + model_O_info.hash, + model_sha256(model_O_info), + base_alpha, + weight_A, + weight_B, + weight_name, + ) return ret_html btn_clear_weighted.click( - fn=lambda: [gr.update(value=0.5) for _ in range(25*2)], + fn=lambda: [gr.update(value=0.5) for _ in range(25 * 2)], inputs=[], outputs=[ - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, sl_M_A_00, - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, sl_M_B_00, - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, - ] + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + ], ) def on_change_dd_preset_weight(dd_preset_weight): _weights = presetWeights.find_weight_by_name(dd_preset_weight) _ret = on_btn_apply_block_weight_from_txt(_weights) return [gr.update(value=_weights)] + _ret + dd_preset_weight.change( fn=on_change_dd_preset_weight, inputs=[dd_preset_weight], outputs=[ txt_block_weight, - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, sl_M_A_00, - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, sl_M_B_00, - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, - ] + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + ], ) def on_btn_reload_checkpoint_mbw(): sd_models.list_models() - return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] + return [ + gr.update(choices=sd_models.checkpoint_tiles()), + gr.update(choices=sd_models.checkpoint_tiles()), + ] + btn_reload_checkpoint_mbw.click( - fn=on_btn_reload_checkpoint_mbw, - inputs=[], - outputs=[dd_model_A, dd_model_B] + fn=on_btn_reload_checkpoint_mbw, inputs=[], outputs=[dd_model_A, dd_model_B] ) def on_btn_apply_block_weight_from_txt(txt_block_weight): if not txt_block_weight or txt_block_weight == "": - return [gr.update() for _ in range(25*2)] + return [gr.update() for _ in range(25 * 2)] _list = [x.strip() for x in txt_block_weight.split(",")] - if(len(_list) != 25): - return [gr.update() for _ in range(25*2)] - return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list] + if len(_list) != 25: + return [gr.update() for _ in range(25 * 2)] + return [gr.update(value=str(1 - float(x))) for x in _list] + [ + gr.update(value=x) for x in _list + ] + btn_apply_block_weithg_from_txt.click( fn=on_btn_apply_block_weight_from_txt, inputs=[txt_block_weight], outputs=[ - sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, - sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, + sl_IN_A_00, + sl_IN_A_01, + sl_IN_A_02, + sl_IN_A_03, + sl_IN_A_04, + sl_IN_A_05, + sl_IN_A_06, + sl_IN_A_07, + sl_IN_A_08, + sl_IN_A_09, + sl_IN_A_10, + sl_IN_A_11, sl_M_A_00, - sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, - sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, - sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, - sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, + sl_OUT_A_00, + sl_OUT_A_01, + sl_OUT_A_02, + sl_OUT_A_03, + sl_OUT_A_04, + sl_OUT_A_05, + sl_OUT_A_06, + sl_OUT_A_07, + sl_OUT_A_08, + sl_OUT_A_09, + sl_OUT_A_10, + sl_OUT_A_11, + sl_IN_B_00, + sl_IN_B_01, + sl_IN_B_02, + sl_IN_B_03, + sl_IN_B_04, + sl_IN_B_05, + sl_IN_B_06, + sl_IN_B_07, + sl_IN_B_08, + sl_IN_B_09, + sl_IN_B_10, + sl_IN_B_11, sl_M_B_00, - sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, - sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, - ] + sl_OUT_B_00, + sl_OUT_B_01, + sl_OUT_B_02, + sl_OUT_B_03, + sl_OUT_B_04, + sl_OUT_B_05, + sl_OUT_B_06, + sl_OUT_B_07, + sl_OUT_B_08, + sl_OUT_B_09, + sl_OUT_B_10, + sl_OUT_B_11, + ], ) diff --git a/scripts/mbw_util/merge_history.py b/scripts/mbw_util/merge_history.py index 6bad8ea..d3edf21 100644 --- a/scripts/mbw_util/merge_history.py +++ b/scripts/mbw_util/merge_history.py @@ -12,14 +12,28 @@ CSV_FILE_ROOT = "csv/" CSV_FILE_PATH = "csv/history.tsv" HEADERS = [ - "model_A", "model_A_hash", "model_A_sha256", - "model_B", "model_B_hash", "model_B_sha256", - "model_O", "model_O_hash", "model_O_sha256", - "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"] + "model_A", + "model_A_hash", + "model_A_sha256", + # + "model_B", + "model_B_hash", + "model_B_sha256", + # + "model_O", + "model_O_hash", + "model_O_sha256", + # + "base_alpha", + "weight_name", + "weight_values", + "weight_values2", + "datetime", +] path_root = scripts.basedir() -class MergeHistory(): +class MergeHistory: def __init__(self): self.fileroot = os.path.join(path_root, CSV_FILE_ROOT) self.filepath = os.path.join(path_root, CSV_FILE_PATH) @@ -28,39 +42,55 @@ def __init__(self): if os.path.exists(self.filepath): self.update_header() - def add_history(self, - model_A_name, model_A_hash, model_A_sha256, - model_B_name, model_B_hash, model_B_sha256, - model_O_name, model_O_hash, model_O_sha256, - sl_base_alpha, - weight_value_A, - weight_value_B, - weight_name=""): + def add_history( + self, + model_A_name, + model_A_hash, + model_A_sha256, + # + model_B_name, + model_B_hash, + model_B_sha256, + # + model_O_name, + model_O_hash, + model_O_sha256, + # + sl_base_alpha, + weight_value_A, + weight_value_B, + weight_name="", + ): _history_dict = {} - _history_dict.update({ - "model_A": model_A_name, - "model_A_hash": model_A_hash, - "model_A_sha256": model_A_sha256, - "model_B": model_B_name, - "model_B_hash": model_B_hash, - "model_B_sha256": model_B_sha256, - "model_O": model_O_name, - "model_O_hash": model_O_hash, - "model_O_sha256": model_O_sha256, - "base_alpha": sl_base_alpha, - "weight_name": weight_name, - "weight_values": weight_value_A, - "weight_values2": weight_value_B, - "datetime": f"{datetime.datetime.now()}" - }) + _history_dict.update( + { + "model_A": model_A_name, + "model_A_hash": model_A_hash, + "model_A_sha256": model_A_sha256, + # + "model_B": model_B_name, + "model_B_hash": model_B_hash, + "model_B_sha256": model_B_sha256, + # + "model_O": model_O_name, + "model_O_hash": model_O_hash, + "model_O_sha256": model_O_sha256, + # + "base_alpha": sl_base_alpha, + "weight_name": weight_name, + "weight_values": weight_value_A, + "weight_values2": weight_value_B, + "datetime": f"{datetime.datetime.now()}", + } + ) if not os.path.exists(self.filepath): with open(self.filepath, "w", newline="", encoding="utf-8") as f: - dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + dw = DictWriter(f, fieldnames=HEADERS, delimiter="\t") dw.writeheader() # save to file - with open(self.filepath, "a", newline="", encoding='utf-8') as f: - dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + with open(self.filepath, "a", newline="", encoding="utf-8") as f: + dw = DictWriter(f, fieldnames=HEADERS, delimiter="\t") dw.writerow(_history_dict) def update_header(self): @@ -68,16 +98,16 @@ def update_header(self): if os.path.exists(self.filepath): # check header in case HEADERS updated with open(self.filepath, "r", newline="", encoding="utf-8") as f: - dr = DictReader(f, delimiter='\t') - new_header = [ x for x in HEADERS if x not in dr.fieldnames ] + dr = DictReader(f, delimiter="\t") + new_header = [x for x in HEADERS if x not in dr.fieldnames] if len(new_header) > 0: # need update. - hist_data = [ x for x in dr] + hist_data = [x for x in dr] # apply change if len(hist_data) > 0: # backup before change shutil.copy(self.filepath, self.filepath + ".bak") with open(self.filepath, "w", newline="", encoding="utf-8") as f: - dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') + dw = DictWriter(f, fieldnames=HEADERS, delimiter="\t") dw.writeheader() dw.writerows(hist_data) diff --git a/scripts/mbw_util/preset_weights.py b/scripts/mbw_util/preset_weights.py index 7867f32..dca6872 100644 --- a/scripts/mbw_util/preset_weights.py +++ b/scripts/mbw_util/preset_weights.py @@ -6,14 +6,13 @@ from modules import scripts - CSV_FILE_PATH = "csv/preset.tsv" MYPRESET_PATH = "csv/preset_own.tsv" HEADER = ["preset_name", "preset_weights"] path_root = scripts.basedir() -class PresetWeights(): +class PresetWeights: def __init__(self): self.presets = {} @@ -22,14 +21,18 @@ def __init__(self): reader = DictReader(f, delimiter="\t") lines_dict = [row for row in reader] for line_dict in lines_dict: - _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")]) + _w = ",".join( + [f"{x.strip()}" for x in line_dict["preset_weights"].split(",")] + ) self.presets.update({line_dict["preset_name"]: _w}) with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f: reader = DictReader(f, delimiter="\t") lines_dict = [row for row in reader] for line_dict in lines_dict: - _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")]) + _w = ",".join( + [f"{x.strip()}" for x in line_dict["preset_weights"].split(",")] + ) self.presets.update({line_dict["preset_name"]: _w}) def get_preset_name_list(self): diff --git a/scripts/merge_block_weighted_extension.py b/scripts/merge_block_weighted_extension.py index e7a6fe5..8868649 100644 --- a/scripts/merge_block_weighted_extension.py +++ b/scripts/merge_block_weighted_extension.py @@ -5,12 +5,8 @@ # 2022/12/14 bbc_mc # -import os import gradio as gr - from modules import script_callbacks - - from scripts.mbw import ui_mbw from scripts.mbw_each import ui_mbw_each @@ -28,7 +24,8 @@ def on_ui_tabs(): ui_mbw_each.on_ui_tabs() # return required as (gradio_component, title, elem_id) - return (main_block, "Merge Block Weighted", "merge_block_weighted"), + return ((main_block, "Merge Block Weighted", "merge_block_weighted"),) + # on_UI script_callbacks.on_ui_tabs(on_ui_tabs)