From ea243682aa9266219839d4c4e11cd14e1fe139a9 Mon Sep 17 00:00:00 2001 From: db0 Date: Wed, 11 Sep 2024 15:58:17 +0200 Subject: [PATCH] feat: support for flux --- CHANGELOG.md | 4 ++++ horde/apis/v2/stable.py | 5 +++++ horde/classes/stable/processing_generation.py | 3 +++ horde/classes/stable/waiting_prompt.py | 16 +++++++--------- horde/consts.py | 2 +- horde/exceptions.py | 1 + horde/model_reference.py | 1 + 7 files changed, 22 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 188c7032..b720d3ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-or-later # Changelog +# 4.42.0 + +* Adds support for the Flux family of models + # 4.41.0 * Adds support for extra backends behind LLM bridges, and for knowing which are validated. diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 813720b2..eca61d8e 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -172,6 +172,11 @@ def validate(self): if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models): if "control_type" in self.params: raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch") + if any(model_reference.get_model_baseline(model_name).startswith("flux.1") for model_name in self.args.models): + if "control_type" in self.params: + raise e.BadRequest("ControlNet does not work with Flux currently.", rc="ControlNetMismatch") + if self.params.get("hires_fix", False) is True: + raise e.BadRequest("HiRes Fix does not work with Flux currently.", rc="HiResMismatch") if "loras" in self.params: if len(self.params["loras"]) > 5: raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras") diff --git a/horde/classes/stable/processing_generation.py b/horde/classes/stable/processing_generation.py index efc70c5f..a95346af 100644 --- a/horde/classes/stable/processing_generation.py +++ b/horde/classes/stable/processing_generation.py @@ -64,6 +64,9 @@ def get_gen_kudos(self): if self.wp.params.get("hires_fix", False): return self.wp.kudos * 7 return self.wp.kudos * 4 + if model_reference.get_model_baseline(self.model) in ["flux.1"]: + # Flux is double the size of SDXL and much slower, so it gives double the rewards from it. + return self.wp.kudos * 8 return self.wp.kudos def log_aborted_generation(self): diff --git a/horde/classes/stable/waiting_prompt.py b/horde/classes/stable/waiting_prompt.py index e82b6e9c..69f380a6 100644 --- a/horde/classes/stable/waiting_prompt.py +++ b/horde/classes/stable/waiting_prompt.py @@ -364,7 +364,7 @@ def require_upfront_kudos(self, counted_totals, total_threads): max_res = 768 # We allow everyone to use SDXL up to 1024 if max_res < 1024 and any( - model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names + model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade", "flux.1"] for mn in model_names ): max_res = 1024 if max_res > 1024: @@ -372,10 +372,8 @@ def require_upfront_kudos(self, counted_totals, total_threads): # Using more than 10 steps with LCM requires upfront kudos if self.is_using_lcm() and self.get_accurate_steps() > 10: return (True, max_res, False) - # Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse. - if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30: - return (True, max_res, False) - if self.get_accurate_steps() > 50: + # Some models don't require a lot of steps, so we check their requirements. The max steps we allow without upfront kudos is 40 + if any(model_reference.get_model_requirements(mn).get("max_steps",40) > self.get_accurate_steps() for mn in model_names): return (True, max_res, False) if self.width * self.height > max_res * max_res: return (True, max_res, False) @@ -400,9 +398,7 @@ def downgrade(self, max_resolution): # Break, just in case we went too low if self.width * self.height < 512 * 512: break - max_steps = 50 - if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in self.get_model_names()): - max_steps = 30 + max_steps = min(model_reference.get_model_requirements(mn).get("max_steps",30) for mn in self.get_model_names()) if self.params.get("control_type"): max_steps = 20 if self.is_using_lcm(): @@ -435,7 +431,7 @@ def get_accurate_steps(self): if self.params.get("sampler_name", "k_euler_a") in ["k_dpm_adaptive"]: # This sampler chooses the steps amount automatically # and disregards the steps value from the user - # so we just calculate it as an average 50 steps + # so we just calculate it as an average 40 steps return 40 steps = self.params["steps"] if self.params.get("sampler_name", "k_euler_a") in SECOND_ORDER_SAMPLERS: @@ -499,6 +495,8 @@ def extrapolate_dry_run_kudos(self): return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1 if model_reference.get_model_baseline(model_name) in ["stable_cascade"]: return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1 + if model_reference.get_model_baseline(model_name) in ["flux.1"]: + return (self.calculate_extra_kudos_burn(kudos) * self.n * 8) + 1 # The +1 is the extra kudos burn per request return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1 diff --git a/horde/consts.py b/horde/consts.py index f833db9c..60da977e 100644 --- a/horde/consts.py +++ b/horde/consts.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later -HORDE_VERSION = "4.41.0 " +HORDE_VERSION = "4.42.0 " WHITELISTED_SERVICE_IPS = { "212.227.227.178", # Turing Bot diff --git a/horde/exceptions.py b/horde/exceptions.py index 0f107dae..d67cca6c 100644 --- a/horde/exceptions.py +++ b/horde/exceptions.py @@ -150,6 +150,7 @@ "InvalidTransparencyModel", "InvalidTransparencyImg2Img", "InvalidTransparencyCN", + "HiResMismatch", ] diff --git a/horde/model_reference.py b/horde/model_reference.py index 79604913..3d6908dd 100644 --- a/horde/model_reference.py +++ b/horde/model_reference.py @@ -53,6 +53,7 @@ def call_function(self): "stable diffusion 2 512", "stable_diffusion_xl", "stable_cascade", + "flux.1", }: self.stable_diffusion_names.add(model) if self.reference[model].get("nsfw"):