diff --git a/scripts/common/state.py b/scripts/common/state.py
index ba7029d..54ef128 100644
--- a/scripts/common/state.py
+++ b/scripts/common/state.py
@@ -1,5 +1,5 @@
import json
-from .utils import load_config, update_size
+from .utils import load_config, update_size, fetch_configuration
class State:
@@ -10,6 +10,7 @@ class State:
configuration = {
"config_file": "config.json",
"config": {},
+ "webui_config": {}
}
presets = {
"presets_file": "presets.json",
@@ -21,6 +22,8 @@ class State:
"busy": False,
}
render = {
+ "checkpoint": None,
+ "vae": None,
"hr_scales": [],
"hr_scale": 1.0,
"hr_scale_prev": 1.25,
@@ -175,6 +178,14 @@ def update_settings(self):
with open(self.json_file, "w") as f:
json.dump(settings, f, indent=4)
+ def update_webui_config(self):
+ """
+ Update webui configuration from the API.
+ """
+ self.configuration["webui_config"] = fetch_configuration(self)
+ self.render['checkpoint'] = self.configuration["webui_config"]['sd_model_checkpoint']
+ self.render['vae'] = self.configuration["webui_config"]['sd_vae']
+
def __setitem__(self, key, value):
setattr(self, key, value)
diff --git a/scripts/common/utils.py b/scripts/common/utils.py
index 29e3375..e325d4a 100644
--- a/scripts/common/utils.py
+++ b/scripts/common/utils.py
@@ -2,6 +2,7 @@
import functools
import os
import random
+import re
import shutil
import threading
import base64
@@ -9,6 +10,8 @@
import json
import time
import math
+
+import requests
from PIL import Image
from psd_tools import PSDImage
@@ -312,5 +315,47 @@ def get_img2img_json(state):
return json_data
+def fetch_configuration(state):
+ """
+ Request current configuration from the webui API.
+ :return: The configuration JSON.
+ """
+
+ response = requests.get(url=f'{state.server["url"]}/sdapi/v1/options')
+ if response.status_code == 200:
+ r = response.json()
+ return r
+ else:
+ return {}
+
+
+checkpoint_pattern = re.compile(r'^(?P
.*(?:\\|\/))?(?P.*?)(?P\.vae)?(?P\.safetensors|\.pt|\.ckpt) ?(?P\[[^\]]*\])?.*')
+
+
+def ckpt_name(name, display_dir=False, display_ext=False, display_hash=False):
+ """
+ Clean checkpoint name.
+ :param str name: Checkpoint name.
+ :param bool display_dir: Display full path.
+ :param bool display_ext: Display checkpoint extension.
+ :param bool display_hash: Display checkpoint hash.
+ :return: Cleaned checkpoint name.
+ """
+
+ replace = ''
+ if display_dir:
+ replace += r'\g'
+
+ replace += r'\g'
+
+ if display_ext:
+ replace += r'\g\g'
+
+ if display_hash:
+ replace += r' \g'
+
+ return checkpoint_pattern.sub(replace, name)
+
+
# Type hinting imports:
from .state import State
diff --git a/scripts/views/PygameView.py b/scripts/views/PygameView.py
index e26cb49..1c8734e 100644
--- a/scripts/views/PygameView.py
+++ b/scripts/views/PygameView.py
@@ -14,7 +14,7 @@
from PIL import Image, ImageOps
import tkinter as tk
from tkinter import filedialog, simpledialog
-from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed
+from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed, ckpt_name
from scripts.common.cn_requests import fetch_controlnet_models, progress_request, fetch_detect_image, fetch_img2img, post_request
from scripts.common.output_files_utils import autosave_image, save_image
from scripts.common.state import State
@@ -715,12 +715,16 @@ def display_configuration(self, wrap=True):
:param bool wrap: Wrap long text.
"""
+ self.state.update_webui_config()
+
fields = [
'--Prompt',
'state/gen_settings/prompt',
'state/gen_settings/negative_prompt',
'state/gen_settings/seed',
'--Render',
+ 'state/render/checkpoint',
+ 'state/render/vae',
'state/render/render_size',
'settings.steps',
'settings.cfg_scale',
@@ -733,7 +737,6 @@ def display_configuration(self, wrap=True):
'state/control_net/controlnet_weight',
'state/control_net/controlnet_guidance_end',
'state/render/pixel_perfect',
- '--Misc',
'state/detectors/detector'
]
@@ -759,7 +762,7 @@ def display_configuration(self, wrap=True):
if '.' in field:
field = field.split('.')
- var = globals().get(field[0], None)
+ var = globals().get(field[0], locals().get(field[0], None))
if var is None:
continue
@@ -779,19 +782,22 @@ def display_configuration(self, wrap=True):
field_value = getattr(self.state, field_components[0])[field_components[1]]
else:
label = field
- field_value = globals().get(field, None)
-
- if 'size' in label and isinstance(field_value, tuple) and len(field_value) == 2:
- field_value = f"{field_value[0]}x{field_value[1]}"
+ field_value = globals().get(field, locals().get(field, None))
if field_value is not None:
value = field_value
if label and value is not None:
- value = str(value)
+ # prettify
label = label.replace('_', ' ')
if label.endswith('prompt'):
value = value.replace(', ', ',').replace(',', ', ') # nicer prompt display
+ elif 'size' in label and isinstance(value, tuple) and len(value) == 2:
+ value = f"{value[0]}x{value[1]}"
+ elif label in ('checkpoint', 'vae'):
+ value = ckpt_name(value)
+ else:
+ value = str(value)
# wrap text
if wrap and len(value) > wrap:
@@ -1180,9 +1186,9 @@ def main(self):
if self.shift_down:
# cycle detectors
self.state.detectors["detector"] = self.state.detectors["list"][(self.state.detectors["list"].index(self.state.detectors["detector"])+1) % len(self.state.detectors["list"])]
- self.osd(text=f"ControlNet detector: {self.state.detectors['detector']}")
+ self.osd(text=f"ControlNet detector: {self.state.detectors['detector'].replace('_', ' ')}")
else:
- self.osd(text=f"Detect {self.state.detectors['detector']}")
+ self.osd(text=f"Detect {self.state.detectors['detector'].replace('_', ' ')}")
detector = str(self.state.detectors['detector'])
t = threading.Thread(target=functools.partial(self.controlnet_detect, detector))