From 76b09b56dc8bf9c5c244c2075e77dd3eb36c3312 Mon Sep 17 00:00:00 2001 From: Danamir Date: Tue, 9 May 2023 18:26:28 +0200 Subject: [PATCH 1/3] Display checkpoint in current configuration --- Scripts/SdPaint.py | 56 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/Scripts/SdPaint.py b/Scripts/SdPaint.py index 098cef7..886584f 100644 --- a/Scripts/SdPaint.py +++ b/Scripts/SdPaint.py @@ -91,6 +91,7 @@ def update_config(config_file, write=False, values=None): presets = load_config(presets_file) settings = {} +webui_config = {} # Setup url = config.get('url', 'http://127.0.0.1:7860') @@ -98,6 +99,8 @@ def update_config(config_file, write=False, values=None): ACCEPTED_FILE_TYPES = ["png", "jpg", "jpeg", "bmp"] # Global variables +checkpoint_pattern = re.compile(r'^(?P.*(?:\\|\/))?(?P.*?)(?P\.vae)?(?P\.safetensors|\.pt|\.ckpt) ?(?P\[[^\]]*\])?.*') + img2img = None img2img_waiting = False img2img_time_prev = None @@ -1272,18 +1275,64 @@ def controlnet_detect(): osd(text=f"Error code returned: HTTP {response.status_code}") +def request_configuration(): + """ + Request current configuration from the webui API. + :return: The configuration JSON. + """ + + response = requests.get(url=f'{url}/sdapi/v1/options') + if response.status_code == 200: + r = response.json() + return r + else: + return {} + + +def ckpt_name(name, display_dir=False, display_ext=False, display_hash=False): + """ + + :param name: + :param remove_dir: + :return: + """ + + replace = '' + if display_dir: + replace += r'\g' + + replace += r'\g' + + if display_ext: + replace += r'\g\g' + + if display_hash: + replace += r' \g' + + return checkpoint_pattern.sub(replace, name) + + def display_configuration(wrap=True): """ Display configuration on screen. :param bool wrap: Wrap long text. """ + global webui_config + + webui_config = request_configuration() + + sd_model_checkpoint = ckpt_name(webui_config['sd_model_checkpoint']) + sd_vae = ckpt_name(webui_config['sd_vae']) + fields = [ '--Prompt', 'prompt', 'negative_prompt', 'seed', '--Render', + 'sd_model_checkpoint', + 'sd_vae', 'settings.steps', 'settings.cfg_scale', 'hr_scale', @@ -1319,7 +1368,7 @@ def display_configuration(wrap=True): if '.' in field: field = field.split('.') - var = globals().get(field[0], None) + var = globals().get(field[0], locals().get(field[0], None)) if var is None: continue @@ -1333,9 +1382,10 @@ def display_configuration(wrap=True): label = field[1] value = getattr(var, field[1]) else: - if globals().get(field, None) is not None: + var = globals().get(field, locals().get(field, None)) + if var is not None: label = field - value = globals().get(field) + value = var if label and value is not None: value = str(value) From 5aef6357c11c473ff5facc4e4f84262b30af8f3f Mon Sep 17 00:00:00 2001 From: Danamir Date: Wed, 10 May 2023 11:05:12 +0200 Subject: [PATCH 2/3] Fetch configuration from the webui API, store it in current State --- scripts/common/state.py | 13 +++++++++++- scripts/common/utils.py | 45 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/scripts/common/state.py b/scripts/common/state.py index ba7029d..54ef128 100644 --- a/scripts/common/state.py +++ b/scripts/common/state.py @@ -1,5 +1,5 @@ import json -from .utils import load_config, update_size +from .utils import load_config, update_size, fetch_configuration class State: @@ -10,6 +10,7 @@ class State: configuration = { "config_file": "config.json", "config": {}, + "webui_config": {} } presets = { "presets_file": "presets.json", @@ -21,6 +22,8 @@ class State: "busy": False, } render = { + "checkpoint": None, + "vae": None, "hr_scales": [], "hr_scale": 1.0, "hr_scale_prev": 1.25, @@ -175,6 +178,14 @@ def update_settings(self): with open(self.json_file, "w") as f: json.dump(settings, f, indent=4) + def update_webui_config(self): + """ + Update webui configuration from the API. + """ + self.configuration["webui_config"] = fetch_configuration(self) + self.render['checkpoint'] = self.configuration["webui_config"]['sd_model_checkpoint'] + self.render['vae'] = self.configuration["webui_config"]['sd_vae'] + def __setitem__(self, key, value): setattr(self, key, value) diff --git a/scripts/common/utils.py b/scripts/common/utils.py index 29e3375..e325d4a 100644 --- a/scripts/common/utils.py +++ b/scripts/common/utils.py @@ -2,6 +2,7 @@ import functools import os import random +import re import shutil import threading import base64 @@ -9,6 +10,8 @@ import json import time import math + +import requests from PIL import Image from psd_tools import PSDImage @@ -312,5 +315,47 @@ def get_img2img_json(state): return json_data +def fetch_configuration(state): + """ + Request current configuration from the webui API. + :return: The configuration JSON. + """ + + response = requests.get(url=f'{state.server["url"]}/sdapi/v1/options') + if response.status_code == 200: + r = response.json() + return r + else: + return {} + + +checkpoint_pattern = re.compile(r'^(?P.*(?:\\|\/))?(?P.*?)(?P\.vae)?(?P\.safetensors|\.pt|\.ckpt) ?(?P\[[^\]]*\])?.*') + + +def ckpt_name(name, display_dir=False, display_ext=False, display_hash=False): + """ + Clean checkpoint name. + :param str name: Checkpoint name. + :param bool display_dir: Display full path. + :param bool display_ext: Display checkpoint extension. + :param bool display_hash: Display checkpoint hash. + :return: Cleaned checkpoint name. + """ + + replace = '' + if display_dir: + replace += r'\g' + + replace += r'\g' + + if display_ext: + replace += r'\g\g' + + if display_hash: + replace += r' \g' + + return checkpoint_pattern.sub(replace, name) + + # Type hinting imports: from .state import State From 6dd65990b97663f8ce33c8119a0adbca862470e1 Mon Sep 17 00:00:00 2001 From: Danamir Date: Wed, 10 May 2023 11:05:56 +0200 Subject: [PATCH 3/3] Display webui selected checkpoint & vae in display configuration OSD --- scripts/views/PygameView.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/views/PygameView.py b/scripts/views/PygameView.py index e26cb49..1c8734e 100644 --- a/scripts/views/PygameView.py +++ b/scripts/views/PygameView.py @@ -14,7 +14,7 @@ from PIL import Image, ImageOps import tkinter as tk from tkinter import filedialog, simpledialog -from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed +from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed, ckpt_name from scripts.common.cn_requests import fetch_controlnet_models, progress_request, fetch_detect_image, fetch_img2img, post_request from scripts.common.output_files_utils import autosave_image, save_image from scripts.common.state import State @@ -715,12 +715,16 @@ def display_configuration(self, wrap=True): :param bool wrap: Wrap long text. """ + self.state.update_webui_config() + fields = [ '--Prompt', 'state/gen_settings/prompt', 'state/gen_settings/negative_prompt', 'state/gen_settings/seed', '--Render', + 'state/render/checkpoint', + 'state/render/vae', 'state/render/render_size', 'settings.steps', 'settings.cfg_scale', @@ -733,7 +737,6 @@ def display_configuration(self, wrap=True): 'state/control_net/controlnet_weight', 'state/control_net/controlnet_guidance_end', 'state/render/pixel_perfect', - '--Misc', 'state/detectors/detector' ] @@ -759,7 +762,7 @@ def display_configuration(self, wrap=True): if '.' in field: field = field.split('.') - var = globals().get(field[0], None) + var = globals().get(field[0], locals().get(field[0], None)) if var is None: continue @@ -779,19 +782,22 @@ def display_configuration(self, wrap=True): field_value = getattr(self.state, field_components[0])[field_components[1]] else: label = field - field_value = globals().get(field, None) - - if 'size' in label and isinstance(field_value, tuple) and len(field_value) == 2: - field_value = f"{field_value[0]}x{field_value[1]}" + field_value = globals().get(field, locals().get(field, None)) if field_value is not None: value = field_value if label and value is not None: - value = str(value) + # prettify label = label.replace('_', ' ') if label.endswith('prompt'): value = value.replace(', ', ',').replace(',', ', ') # nicer prompt display + elif 'size' in label and isinstance(value, tuple) and len(value) == 2: + value = f"{value[0]}x{value[1]}" + elif label in ('checkpoint', 'vae'): + value = ckpt_name(value) + else: + value = str(value) # wrap text if wrap and len(value) > wrap: @@ -1180,9 +1186,9 @@ def main(self): if self.shift_down: # cycle detectors self.state.detectors["detector"] = self.state.detectors["list"][(self.state.detectors["list"].index(self.state.detectors["detector"])+1) % len(self.state.detectors["list"])] - self.osd(text=f"ControlNet detector: {self.state.detectors['detector']}") + self.osd(text=f"ControlNet detector: {self.state.detectors['detector'].replace('_', ' ')}") else: - self.osd(text=f"Detect {self.state.detectors['detector']}") + self.osd(text=f"Detect {self.state.detectors['detector'].replace('_', ' ')}") detector = str(self.state.detectors['detector']) t = threading.Thread(target=functools.partial(self.controlnet_detect, detector))