diff --git a/scripts/common/state.py b/scripts/common/state.py index ba7029d..54ef128 100644 --- a/scripts/common/state.py +++ b/scripts/common/state.py @@ -1,5 +1,5 @@ import json -from .utils import load_config, update_size +from .utils import load_config, update_size, fetch_configuration class State: @@ -10,6 +10,7 @@ class State: configuration = { "config_file": "config.json", "config": {}, + "webui_config": {} } presets = { "presets_file": "presets.json", @@ -21,6 +22,8 @@ class State: "busy": False, } render = { + "checkpoint": None, + "vae": None, "hr_scales": [], "hr_scale": 1.0, "hr_scale_prev": 1.25, @@ -175,6 +178,14 @@ def update_settings(self): with open(self.json_file, "w") as f: json.dump(settings, f, indent=4) + def update_webui_config(self): + """ + Update webui configuration from the API. + """ + self.configuration["webui_config"] = fetch_configuration(self) + self.render['checkpoint'] = self.configuration["webui_config"]['sd_model_checkpoint'] + self.render['vae'] = self.configuration["webui_config"]['sd_vae'] + def __setitem__(self, key, value): setattr(self, key, value) diff --git a/scripts/common/utils.py b/scripts/common/utils.py index 29e3375..e325d4a 100644 --- a/scripts/common/utils.py +++ b/scripts/common/utils.py @@ -2,6 +2,7 @@ import functools import os import random +import re import shutil import threading import base64 @@ -9,6 +10,8 @@ import json import time import math + +import requests from PIL import Image from psd_tools import PSDImage @@ -312,5 +315,47 @@ def get_img2img_json(state): return json_data +def fetch_configuration(state): + """ + Request current configuration from the webui API. + :return: The configuration JSON. + """ + + response = requests.get(url=f'{state.server["url"]}/sdapi/v1/options') + if response.status_code == 200: + r = response.json() + return r + else: + return {} + + +checkpoint_pattern = re.compile(r'^(?P.*(?:\\|\/))?(?P.*?)(?P\.vae)?(?P\.safetensors|\.pt|\.ckpt) ?(?P\[[^\]]*\])?.*') + + +def ckpt_name(name, display_dir=False, display_ext=False, display_hash=False): + """ + Clean checkpoint name. + :param str name: Checkpoint name. + :param bool display_dir: Display full path. + :param bool display_ext: Display checkpoint extension. + :param bool display_hash: Display checkpoint hash. + :return: Cleaned checkpoint name. + """ + + replace = '' + if display_dir: + replace += r'\g' + + replace += r'\g' + + if display_ext: + replace += r'\g\g' + + if display_hash: + replace += r' \g' + + return checkpoint_pattern.sub(replace, name) + + # Type hinting imports: from .state import State diff --git a/scripts/views/PygameView.py b/scripts/views/PygameView.py index e26cb49..1c8734e 100644 --- a/scripts/views/PygameView.py +++ b/scripts/views/PygameView.py @@ -14,7 +14,7 @@ from PIL import Image, ImageOps import tkinter as tk from tkinter import filedialog, simpledialog -from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed +from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed, ckpt_name from scripts.common.cn_requests import fetch_controlnet_models, progress_request, fetch_detect_image, fetch_img2img, post_request from scripts.common.output_files_utils import autosave_image, save_image from scripts.common.state import State @@ -715,12 +715,16 @@ def display_configuration(self, wrap=True): :param bool wrap: Wrap long text. """ + self.state.update_webui_config() + fields = [ '--Prompt', 'state/gen_settings/prompt', 'state/gen_settings/negative_prompt', 'state/gen_settings/seed', '--Render', + 'state/render/checkpoint', + 'state/render/vae', 'state/render/render_size', 'settings.steps', 'settings.cfg_scale', @@ -733,7 +737,6 @@ def display_configuration(self, wrap=True): 'state/control_net/controlnet_weight', 'state/control_net/controlnet_guidance_end', 'state/render/pixel_perfect', - '--Misc', 'state/detectors/detector' ] @@ -759,7 +762,7 @@ def display_configuration(self, wrap=True): if '.' in field: field = field.split('.') - var = globals().get(field[0], None) + var = globals().get(field[0], locals().get(field[0], None)) if var is None: continue @@ -779,19 +782,22 @@ def display_configuration(self, wrap=True): field_value = getattr(self.state, field_components[0])[field_components[1]] else: label = field - field_value = globals().get(field, None) - - if 'size' in label and isinstance(field_value, tuple) and len(field_value) == 2: - field_value = f"{field_value[0]}x{field_value[1]}" + field_value = globals().get(field, locals().get(field, None)) if field_value is not None: value = field_value if label and value is not None: - value = str(value) + # prettify label = label.replace('_', ' ') if label.endswith('prompt'): value = value.replace(', ', ',').replace(',', ', ') # nicer prompt display + elif 'size' in label and isinstance(value, tuple) and len(value) == 2: + value = f"{value[0]}x{value[1]}" + elif label in ('checkpoint', 'vae'): + value = ckpt_name(value) + else: + value = str(value) # wrap text if wrap and len(value) > wrap: @@ -1180,9 +1186,9 @@ def main(self): if self.shift_down: # cycle detectors self.state.detectors["detector"] = self.state.detectors["list"][(self.state.detectors["list"].index(self.state.detectors["detector"])+1) % len(self.state.detectors["list"])] - self.osd(text=f"ControlNet detector: {self.state.detectors['detector']}") + self.osd(text=f"ControlNet detector: {self.state.detectors['detector'].replace('_', ' ')}") else: - self.osd(text=f"Detect {self.state.detectors['detector']}") + self.osd(text=f"Detect {self.state.detectors['detector'].replace('_', ' ')}") detector = str(self.state.detectors['detector']) t = threading.Thread(target=functools.partial(self.controlnet_detect, detector))