Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allows LLM bridge to handle different backends and inform the horde #449

Merged
merged 11 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.41.0

* Adds support for extra inference backends behind LLM bridges, and a `validated_backends` request option that controls whether backends not validated by the horde devs may serve a request.

# 4.40.3

* Ensure jobs don't expire soon after being picked up
Expand Down
10 changes: 9 additions & 1 deletion horde/apis/models/kobold_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, api):
},
)
self.response_model_job_pop = api.model(
"GenerationPayload",
"GenerationPayloadKobold",
{
"payload": fields.Nested(self.response_model_generation_payload, skip_none=True),
"id": fields.String(description="The UUID for this text generation."),
Expand Down Expand Up @@ -294,6 +294,14 @@ def __init__(self, api):
"When False, Evaluating workers will also be used which can increase speed but adds more risk!"
),
),
"validated_backends": fields.Boolean(
default=True,
description=(
f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. "
"When False, non-validated backends will also be used which can increase speed but "
"you may end up with unexpected results."
),
),
"slow_workers": fields.Boolean(
default=True,
description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
Expand Down
8 changes: 8 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,14 @@ def __init__(self, api):
"When False, Evaluating workers will also be used which can increase speed but adds more risk!"
),
),
"validated_backends": fields.Boolean(
default=True,
description=(
f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. "
"When False, non-validated backends will also be used which can increase speed but "
"you may end up with unexpected results."
),
),
"slow_workers": fields.Boolean(
default=True,
description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
Expand Down
10 changes: 10 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def __init__(self):
"When False, Evaluating workers will also be used.",
location="json",
)
self.generate_parser.add_argument(
    "validated_backends",
    type=bool,
    required=False,
    # Default to True so the parsed value agrees with the documented API-model
    # default and with the WaitingPrompt.validated_backends column default.
    # initiate_waiting_prompt passes self.args.validated_backends through
    # explicitly, so (presumably) this parser default is what applies when the
    # client omits the field - TODO confirm self.args comes from this parser.
    default=True,
    help=f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. "
    "When False, non-validated backends will also be used which can increase speed but "
    "you may end up with unexpected results.",
    location="json",
)
self.generate_parser.add_argument(
"workers",
type=list,
Expand Down
1 change: 1 addition & 0 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def initiate_waiting_prompt(self):
nsfw=self.args.nsfw,
censor_nsfw=self.args.censor_nsfw,
trusted_workers=self.args.trusted_workers,
validated_backends=self.args.validated_backends,
worker_blacklist=self.args.worker_blacklist,
ipaddr=self.user_ip,
sharedkey_id=self.args.apikey if self.sharedkey else None,
Expand Down
2 changes: 1 addition & 1 deletion horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def initiate_waiting_prompt(self):
params=self.params,
softprompt=self.args.softprompt,
trusted_workers=self.args.trusted_workers,
validated_backends=self.args.validated_backends,
worker_blacklist=self.args.worker_blacklist,
slow_workers=self.args.slow_workers,
ipaddr=self.user_ip,
Expand Down Expand Up @@ -311,7 +312,6 @@ def get_sorted_wp(self, priority_user_ids=None):
priority_user_ids=priority_user_ids,
page=self.wp_page,
)

return sorted_wps


Expand Down
1 change: 1 addition & 0 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def initiate_waiting_prompt(self):
nsfw=self.args.nsfw,
censor_nsfw=self.args.censor_nsfw,
trusted_workers=self.args.trusted_workers,
validated_backends=self.args.validated_backends,
worker_blacklist=self.args.worker_blacklist,
slow_workers=self.args.slow_workers,
source_processing=self.args.source_processing,
Expand Down
8 changes: 8 additions & 0 deletions horde/bridge_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@
},
}

# Bridge names treated as validated LLM backends. is_backed_validated() checks
# the bridge name returned by parse_bridge_agent() for membership in this set.
LLM_VALIDATED_BACKENDS = {"AI Horde Worker", "AI Horde Worker~aphrodite~oai", "AI Horde Worker~aphrodite~kai", "KoboldCppEmbedWorker"}


@logger.catch(reraise=True)
def parse_bridge_agent(bridge_agent):
Expand Down Expand Up @@ -195,6 +197,12 @@ def check_bridge_capability(capability, bridge_agent):
return capability in total_capabilities


@logger.catch(reraise=True)
def is_backed_validated(bridge_agent):
    """Return True when the bridge named in ``bridge_agent`` is a validated LLM backend.

    ``parse_bridge_agent`` splits the agent string into a pair; only the first
    element (the bridge name) is looked up in ``LLM_VALIDATED_BACKENDS``.
    """
    name, _version = parse_bridge_agent(bridge_agent)
    is_validated = name in LLM_VALIDATED_BACKENDS
    return is_validated


@logger.catch(reraise=True)
def get_supported_samplers(bridge_agent, karras=True):
bridge_name, bridge_version = parse_bridge_agent(bridge_agent)
Expand Down
1 change: 1 addition & 0 deletions horde/classes/base/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class WaitingPrompt(db.Model):
ipaddr = db.Column(db.String(39)) # ipv6
safe_ip = db.Column(db.Boolean, default=False, nullable=False)
trusted_workers = db.Column(db.Boolean, default=False, nullable=False, index=True)
validated_backends = db.Column(db.Boolean, default=True, nullable=False, index=True)
slow_workers = db.Column(db.Boolean, default=True, nullable=False, index=True)
worker_blacklist = db.Column(db.Boolean, default=False, nullable=False, index=True)
faulted = db.Column(db.Boolean, default=False, nullable=False, index=True)
Expand Down
8 changes: 7 additions & 1 deletion horde/classes/kobold/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import os

from horde import vars as hv
from horde.bridge_reference import (
is_backed_validated,
)
from horde.classes.base.processing_generation import ProcessingGeneration
from horde.classes.kobold.genstats import record_text_statistic
from horde.flask import db
Expand Down Expand Up @@ -56,7 +59,10 @@ def get_gen_kudos(self):
# This is the approximate reward for generating with a 2.7 model at 4bit
model_multiplier = model_reference.get_text_model_multiplier(self.model)
parameter_bonus = (max(model_multiplier, 13) / 13) ** 0.20
kudos = self.get_things_count() * parameter_bonus * model_multiplier / 100
kudos = self.get_things_count() * parameter_bonus * model_multiplier / 125
# Unvalidated backends have their rewards cut to 30%
if not is_backed_validated(self.worker.bridge_agent):
kudos *= 0.3
return round(kudos * context_multiplier, 2)

def log_aborted_generation(self):
Expand Down
8 changes: 8 additions & 0 deletions horde/classes/kobold/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from horde import exceptions as e
from horde import horde_redis as hr
from horde.bridge_reference import (
is_backed_validated,
)
from horde.classes.base.worker import Worker
from horde.flask import SQLITE_MODE, db
from horde.logger import logger
Expand Down Expand Up @@ -114,6 +117,9 @@ def calculate_uptime_reward(self):
param_multiplier = model_reference.get_text_model_multiplier(model) / 7
if param_multiplier < 0.25:
param_multiplier = 0.25
# Unvalidated backends get less kudos
if not is_backed_validated(self.bridge_agent):
base_kudos *= 0.3
# The uptime is based on both how much context they provide, as well as how many parameters they're serving
return round(base_kudos * param_multiplier, 2)

Expand All @@ -125,6 +131,8 @@ def can_generate(self, waiting_prompt):
return [False, "max_context_length"]
if self.max_length < waiting_prompt.max_length:
return [False, "max_length"]
if waiting_prompt.validated_backends and not is_backed_validated(self.bridge_agent):
return [False, "bridge_version"]
matching_softprompt = True
if waiting_prompt.softprompt:
matching_softprompt = False
Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

HORDE_VERSION = "4.40.3 "
HORDE_VERSION = "4.41.0 "

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
2 changes: 1 addition & 1 deletion horde/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def count_things_for_specific_model(wp_class, procgen_class, model_name):


def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, priority_user_ids=None, page=0):
# This is just the top 25 - Adjusted method to send ImageWorker object. Filters to add.
# This is just the top 3 - Adjusted method to send ImageWorker object. Filters to add.
# TODO: Filter by ImageWorker not in WP.tricked_worker
# TODO: If any word in the prompt is in the WP.blacklist rows, then exclude it (L293 in base.worker.ImageWorker.gan_generate())
PER_PAGE = 3 # how many requests we're picking up to filter further
Expand Down
9 changes: 8 additions & 1 deletion horde/database/text_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import horde.classes.base.stats as stats
from horde import horde_redis as hr
from horde.bridge_reference import (
is_backed_validated,
)
from horde.classes.base.waiting_prompt import WPAllowedWorkers, WPModels
from horde.classes.base.worker import WorkerPerformance
from horde.classes.kobold.processing_generation import TextProcessingGeneration
Expand All @@ -30,7 +33,7 @@ def convert_things_to_kudos(things, **kwargs):


def get_sorted_text_wp_filtered_to_worker(worker, models_list=None, priority_user_ids=None, page=0):
# This is just the top 100 - Adjusted method to send Worker object. Filters to add.
# This is just the top 3 - Adjusted method to send Worker object. Filters to add.
# TODO: Filter by (Worker in WP.workers) __ONLY IF__ len(WP.workers) >=1
# TODO: Filter by WP.trusted_workers == False __ONLY IF__ Worker.user.trusted == False
# TODO: Filter by Worker not in WP.tricked_worker
Expand Down Expand Up @@ -91,6 +94,10 @@ def get_sorted_text_wp_filtered_to_worker(worker, models_list=None, priority_use
worker.maintenance == False, # noqa E712
TextWaitingPrompt.user_id == worker.user_id,
),
or_(
is_backed_validated(worker.bridge_agent),
TextWaitingPrompt.validated_backends.is_(False),
),
)
)
if priority_user_ids:
Expand Down
1 change: 1 addition & 0 deletions sql_statements/4.41.0.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE waiting_prompts ADD COLUMN validated_backends BOOLEAN default true;
3 changes: 3 additions & 0 deletions sql_statements/4.41.0.txt.license
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SPDX-FileCopyrightText: Konstantinos Thoukydidis <[email protected]>

SPDX-License-Identifier: AGPL-3.0-or-later
9 changes: 7 additions & 2 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
async_dict = {
"prompt": "a horde of cute stable robots in a sprawling server room repairing a massive mainframe",
"trusted_workers": True,
"validated_backends": False,
"max_length": 512,
"max_context_length": 2048,
"temperature": 1,
Expand All @@ -37,9 +38,13 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert pop_req.ok, pop_req.text
pop_results = pop_req.json()
# print(json.dumps(pop_results, indent=4))

job_id = pop_results["id"]
assert job_id is not None, pop_results
try:
assert job_id is not None, pop_results
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/text/status/{req_id}", headers=headers)
print("Request cancelled")
raise err
submit_dict = {
"id": job_id,
"generation": "test ",
Expand Down
Loading