Skip to content

Commit

Permalink
feat: support for flux
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Sep 11, 2024
1 parent 4c0b12e commit ea24368
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.42.0

* Adds support for the Flux family of models

# 4.41.0

* Adds support for extra backends behind LLM bridges, and for knowing which are validated.
Expand Down
5 changes: 5 additions & 0 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ def validate(self):
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("flux.1") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Flux currently.", rc="ControlNetMismatch")
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("HiRes Fix does not work with Flux currently.", rc="HiResMismatch")
if "loras" in self.params:
if len(self.params["loras"]) > 5:
raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras")
Expand Down
3 changes: 3 additions & 0 deletions horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def get_gen_kudos(self):
if self.wp.params.get("hires_fix", False):
return self.wp.kudos * 7
return self.wp.kudos * 4
if model_reference.get_model_baseline(self.model) in ["flux.1"]:
# Flux is double the size of SDXL and much slower, so it gives double the rewards from it.
return self.wp.kudos * 8
return self.wp.kudos

def log_aborted_generation(self):
Expand Down
16 changes: 7 additions & 9 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,16 @@ def require_upfront_kudos(self, counted_totals, total_threads):
max_res = 768
# We allow everyone to use SDXL up to 1024
if max_res < 1024 and any(
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade", "flux.1"] for mn in model_names
):
max_res = 1024
if max_res > 1024:
max_res = 1024
# Using more than 10 steps with LCM requires upfront kudos
if self.is_using_lcm() and self.get_accurate_steps() > 10:
return (True, max_res, False)
# Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse.
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30:
return (True, max_res, False)
if self.get_accurate_steps() > 50:
# Some models don't require a lot of steps, so we check their requirements. The max steps we allow without upfront kudos is 40
if any(model_reference.get_model_requirements(mn).get("max_steps",40) > self.get_accurate_steps() for mn in model_names):
return (True, max_res, False)
if self.width * self.height > max_res * max_res:
return (True, max_res, False)
Expand All @@ -400,9 +398,7 @@ def downgrade(self, max_resolution):
# Break, just in case we went too low
if self.width * self.height < 512 * 512:
break
max_steps = 50
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in self.get_model_names()):
max_steps = 30
max_steps = min(model_reference.get_model_requirements(mn).get("max_steps",30) for mn in self.get_model_names())
if self.params.get("control_type"):
max_steps = 20
if self.is_using_lcm():
Expand Down Expand Up @@ -435,7 +431,7 @@ def get_accurate_steps(self):
if self.params.get("sampler_name", "k_euler_a") in ["k_dpm_adaptive"]:
# This sampler chooses the steps amount automatically
# and disregards the steps value from the user
# so we just calculate it as an average 50 steps
# so we just calculate it as an average 40 steps
return 40
steps = self.params["steps"]
if self.params.get("sampler_name", "k_euler_a") in SECOND_ORDER_SAMPLERS:
Expand Down Expand Up @@ -499,6 +495,8 @@ def extrapolate_dry_run_kudos(self):
return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1
if model_reference.get_model_baseline(model_name) in ["stable_cascade"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1
if model_reference.get_model_baseline(model_name) in ["flux.1"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 8) + 1
# The +1 is the extra kudos burn per request
return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1

Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

HORDE_VERSION = "4.41.0 "
HORDE_VERSION = "4.42.0 "

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
1 change: 1 addition & 0 deletions horde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"InvalidTransparencyModel",
"InvalidTransparencyImg2Img",
"InvalidTransparencyCN",
"HiResMismatch",
]


Expand Down
1 change: 1 addition & 0 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def call_function(self):
"stable diffusion 2 512",
"stable_diffusion_xl",
"stable_cascade",
"flux.1",
}:
self.stable_diffusion_names.add(model)
if self.reference[model].get("nsfw"):
Expand Down

0 comments on commit ea24368

Please sign in to comment.