Skip to content

Commit

Permalink
Allow the override of samplers and upscalers by configuration ones
Browse files Browse the repository at this point in the history
  • Loading branch information
Danamir committed Nov 11, 2023
1 parent 5788abb commit 7add6ef
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
3 changes: 3 additions & 0 deletions configs/config.json-dist
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
1.5,
2
],
"override_hr_upscalers": "false",
"hr_upscalers": [
"Latent (bicubic)",
"R-ESRGAN 4x+"
Expand All @@ -48,8 +49,10 @@
"pidinet_sketch",
"pidinet_scribble"
],
"override_samplers": "false",
"samplers": [
"DPM++ 2M Karras",
"DPM++ 2M SDE Karras",
"DDIM",
"Euler",
"Euler a",
Expand Down
11 changes: 10 additions & 1 deletion scripts/common/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def update_samplers(self):
"""
Update samplers list from available samplers.
"""
if self.configuration["config"].get("override_samplers", 'false') == 'true' and self.configuration["config"].get("samplers", []):
self.samplers["list"] = self.configuration["config"].get("samplers", ["DDIM"])
self.samplers["sampler"] = self.samplers["list"][0]
return

def get_sampler_priority(smp):
priorities = {
Expand Down Expand Up @@ -127,10 +131,15 @@ def update_upscalers(self):
"""
Update upscalers list from available upscalers.
"""
if self.configuration["config"].get("override_hr_upscalers", 'false') == 'true' and self.configuration["config"].get("hr_upscalers", []):
self.render["hr_upscalers"] = self.configuration["config"].get("hr_upscalers", ['Latent (bicubic)'])
self.render["hr_upscaler"] = self.render["hr_upscalers"][0]
return

upscalers_data = self.api.get_upscalers()
upscalers_names = list(map(lambda x: x["name"], upscalers_data))
if len(upscalers_names) == 0:
upscalers_names = self.configuration["config"].get("upscalers", ["Latent (bicubic)"])
upscalers_names = self.configuration["config"].get("hr_upscalers", ["Latent (bicubic)"])
upscalers_names.sort()
self.render["hr_upscalers"] = upscalers_names
self.render["hr_upscaler"] = self.render["hr_upscalers"][0]
Expand Down

0 comments on commit 7add6ef

Please sign in to comment.