Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Flux support #450

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.42.0

* Adds support for the Flux family of models

# 4.41.0

* Adds support for extra backends behind LLM bridges, and for knowing which are validated.
Expand Down
5 changes: 5 additions & 0 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ def validate(self):
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("flux_1") for model_name in self.args.models):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Flux currently.", rc="ControlNetMismatch")
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("HiRes Fix does not work with Flux currently.", rc="HiResMismatch")
if "loras" in self.params:
if len(self.params["loras"]) > 5:
raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras")
Expand Down
3 changes: 3 additions & 0 deletions horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def get_gen_kudos(self):
if self.wp.params.get("hires_fix", False):
return self.wp.kudos * 7
return self.wp.kudos * 4
if model_reference.get_model_baseline(self.model) in ["flux_1"]:
# Flux is double the size of SDXL and much slower, so it gives double the rewards from it.
return self.wp.kudos * 8
return self.wp.kudos

def log_aborted_generation(self):
Expand Down
24 changes: 15 additions & 9 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from horde.classes.base.waiting_prompt import WaitingPrompt
from horde.classes.stable.kudos import KudosModel
from horde.consts import (
BASELINE_BATCHING_MULTIPLIERS,
HEAVY_POST_PROCESSORS,
KNOWN_LCM_LORA_IDS,
KNOWN_LCM_LORA_VERSIONS,
Expand Down Expand Up @@ -364,18 +365,16 @@ def require_upfront_kudos(self, counted_totals, total_threads):
max_res = 768
# We allow everyone to use SDXL up to 1024
if max_res < 1024 and any(
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade", "flux_1"] for mn in model_names
):
max_res = 1024
if max_res > 1024:
max_res = 1024
# Using more than 10 steps with LCM requires upfront kudos
if self.is_using_lcm() and self.get_accurate_steps() > 10:
return (True, max_res, False)
# Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse.
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30:
return (True, max_res, False)
if self.get_accurate_steps() > 50:
# Some models don't require a lot of steps, so we check their requirements. The max steps we allow without upfront kudos is 40
if any(model_reference.get_model_requirements(mn).get("max_steps", 40) > self.get_accurate_steps() for mn in model_names):
return (True, max_res, False)
if self.width * self.height > max_res * max_res:
return (True, max_res, False)
Expand All @@ -400,9 +399,7 @@ def downgrade(self, max_resolution):
# Break, just in case we went too low
if self.width * self.height < 512 * 512:
break
max_steps = 50
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in self.get_model_names()):
max_steps = 30
max_steps = min(model_reference.get_model_requirements(mn).get("max_steps", 30) for mn in self.get_model_names())
if self.params.get("control_type"):
max_steps = 20
if self.is_using_lcm():
Expand Down Expand Up @@ -435,7 +432,7 @@ def get_accurate_steps(self):
if self.params.get("sampler_name", "k_euler_a") in ["k_dpm_adaptive"]:
# This sampler chooses the steps amount automatically
# and disregards the steps value from the user
# so we just calculate it as an average 50 steps
# so we just calculate it as an average 40 steps
return 40
steps = self.params["steps"]
if self.params.get("sampler_name", "k_euler_a") in SECOND_ORDER_SAMPLERS:
Expand Down Expand Up @@ -499,6 +496,8 @@ def extrapolate_dry_run_kudos(self):
return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1
if model_reference.get_model_baseline(model_name) in ["stable_cascade"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1
if model_reference.get_model_baseline(model_name) in ["flux_1"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 8) + 1
# The +1 is the extra kudos burn per request
return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1

Expand All @@ -513,5 +512,12 @@ def has_heavy_operations(self):
return True
return False

def get_highest_model_batching_multiplier(self):
highest_multiplier = 1
for mn in self.get_model_names():
if BASELINE_BATCHING_MULTIPLIERS.get(mn, 1) > highest_multiplier:
highest_multiplier = BASELINE_BATCHING_MULTIPLIERS.get(mn, 1)
return highest_multiplier

def count_pp(self):
return len(self.params.get("post_processing", []))
1 change: 1 addition & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def get_safe_amount(self, amount, wp):
if wp.has_heavy_operations():
pp_multiplier *= 1.8
mps *= pp_multiplier
mps *= wp.get_highest_model_batching_multiplier()
safe_amount = round(safe_generations / mps)
if safe_amount > amount:
safe_amount = amount
Expand Down
9 changes: 8 additions & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

HORDE_VERSION = "4.41.0 "
HORDE_VERSION = "4.42.0 "

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down Expand Up @@ -38,6 +38,13 @@
"4x_AnimeSharp" "CodeFormers",
}

# These models are very large in VRAM, so we increase the calculated MPS
# used to figure out batches by a set multiplier to reduce how many images are batched
# at a time when these models are used.
BASELINE_BATCHING_MULTIPLIERS = {
"flux_1": 2,
}


KNOWN_SAMPLERS = {
"k_lms",
Expand Down
1 change: 1 addition & 0 deletions horde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"InvalidTransparencyModel",
"InvalidTransparencyImg2Img",
"InvalidTransparencyCN",
"HiResMismatch",
]


Expand Down
1 change: 1 addition & 0 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def call_function(self):
"stable diffusion 2 512",
"stable_diffusion_xl",
"stable_cascade",
"flux_1",
}:
self.stable_diffusion_names.add(model)
if self.reference[model].get("nsfw"):
Expand Down
File renamed without changes.
Binary file added img_stable/2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions img_stable/2.jpg.license
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis <[email protected]>

SPDX-License-Identifier: CC0-1.0
Binary file added img_stable/3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions img_stable/3.jpg.license
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis <[email protected]>

SPDX-License-Identifier: CC0-1.0
Loading