diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8add8913..b0acc038 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,14 +16,11 @@ repos: rev: 24.4.2 hooks: - id: black - exclude: ^hordelib/nodes/.*\..*$ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.4.3 hooks: - id: ruff -repos: - - repo: https://github.com/fsfe/reuse-tool +- repo: https://github.com/fsfe/reuse-tool rev: v4.0.3 hooks: - id: reuse - \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 15aa7b90..03b645bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,77 @@ SPDX-License-Identifier: AGPL-3.0-or-later # Changelog +# 4.44.3 + +* Fix image validation warnings being sent to the wrong requests +* Validate request with styles, only after style is applied + +# 4.44.2 + +* Allow trusted users to also create styles +* Fix styles always returning 1 image +* Fix style reward not being taken from the requesting user + +# 4.44.1 + +* Various fixes around styles +* Added `/.well-known/serviceinfo` + +# 4.44.0 + +* Adds styles +* Adds TabbyAPI as approved LLM backend + +# 4.43.9 + +* Prevent anon gens being visible at their profile page + +# 4.43.7 + +* fixes setting team for worker +* aborted jobs can't be restarted anymore + + +# 4.43.6 + +* Fix returning `done` when a job was restarted. + +# 4.43.5 + +* Fix: Added check to ensure the redis servers are still available. 
+ +# 4.43.4 + +* Fix logic error when setting censored key + +# 4.43.3 + +* Add new `information` metadata key + +# 4.43.2 + +* Horde more accurately reports which images are nsfw or csam censored + +# 4.43.1 + +* Fix to prevent limit_max_steps picking up WPs with empty model lists + +# 4.43.0 + +* Adjused TTL formula to be algorithmic +* prevent workers without flux support picking up flux jobs +* Adds `extra_slow_workers` bool for image gen async +* Adds `extra_slow_worker` bool for worker pop +* Adds `limit_max_steps` for worker pop + +# 4.42.0 + +* Adds support for the Flux family of models + +# 4.41.0 + +* Adds support for extra backends behind LLM bridges, and for knowing which are validated. + # 4.40.3 * Ensure jobs don't expire soon after being picked up diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c201a563..68c256ba 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,6 +12,8 @@ SPDX-License-Identifier: AGPL-3.0-or-later * start server with `python server.py -vvvvi --horde stable` * You can now connect to http://localhost:7001 +To run the AI Horde with Docker or Docker Compose, see the [README_docker.md](README_docker.md). + # How to contribute to the AI Horde code We are happy you have ideas to improve this service and we welcome all contributors. @@ -85,4 +87,4 @@ Note that the pre-commit will not complain if you forget to add your copyright n # Code of Conduct -We expect all contributors to follow the [Anarchist code of conduct](https://wiki.dbzer0.com/the-anarchist-code-of-conduct/). Do not drive away other contributors due to intended or unintended on bigotry. \ No newline at end of file +We expect all contributors to follow the [Anarchist code of conduct](https://wiki.dbzer0.com/the-anarchist-code-of-conduct/). Do not drive away other contributors due to intended or unintended on bigotry. 
diff --git a/Dockerfile b/Dockerfile index 87bf9cf3..88ad21fd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,25 +1,52 @@ # SPDX-FileCopyrightText: 2024 Tazlin +# SPDX-FileCopyrightText: 2024 ceruleandeep # # SPDX-License-Identifier: AGPL-3.0-or-later # Use a slim base image for Python 3.10 -FROM python:3.10-slim +FROM python:3.10-slim AS python + + +## +## BUILD STAGE +## +FROM python AS python-build-stage # Install Git RUN apt-get update && apt-get install -y git -# Set the working directory +RUN --mount=type=cache,target=/root/.cache pip install --upgrade pip + +# Build dependencies +COPY ./requirements.txt . +RUN --mount=type=cache,target=/root/.cache \ + pip wheel --wheel-dir /usr/src/app/wheels \ + -r requirements.txt + + +## +## RUN STAGE +## +FROM python AS python-run-stage + +# git is required in the run stage because one dependency is not available in PyPI +RUN apt-get update && apt-get install -y git + +RUN --mount=type=cache,target=/root/.cache pip install --upgrade pip + +# Install dependencies +COPY --from=python-build-stage /usr/src/app/wheels /wheels/ +COPY ./requirements.txt . +RUN pip install --no-cache-dir --no-index --find-links=/wheels/ \ + -r requirements.txt \ + && rm -rf /wheels/ + WORKDIR /app -# Copy the source code to the container COPY . /app -# Install the dependencies -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --no-cache-dir --prefer-binary -r requirements.txt - # Set the environment variables -ENV PROFILE= +ENV PROFILE="" # Set the command to run when the container starts CMD ["python", "server.py", "-vvvvi", "--horde", "stable"] diff --git a/README_docker.md b/README_docker.md index 3a823ffc..3a5cfdff 100644 --- a/README_docker.md +++ b/README_docker.md @@ -22,16 +22,29 @@ Run the following command in your project root folder (the folder where your Doc docker build -t aihorde:latest . 
``` -## with Docker-compose +## with Docker Compose -Create `.env_docker` file to deliver access information of services used together such as Redis and Postgres. +[docker-compose.yaml](docker-compose.yaml) is provided to run the AI-Horde with Redis and Postgres. -Copy the `.env_template` file in the root folder to create the .env_docker file. +Copy the `.env_template` file in the root folder to create the `.env_docker` file. -[docker-compose.yaml](docker-compose.yaml) Change the file as needed. +```bash +cp .env_template .env_docker +``` + +To use the supplied `.env_template` with the supplied `docker-compose.yaml`, you will need to set: + +```bash +# .env_docker +REDIS_IP="redis" +REDIS_SERVERS='["redis"]' +USE_SQLITE=0 +POSTGRES_URL="postgres" +``` +Then run the following command in your project root folder: ```bash # run in background -docker-compose up -d +docker compose up --build -d ``` diff --git a/docker-compose.yaml b/docker-compose.yaml index ccd88d2a..1ef32113 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,19 +1,28 @@ # SPDX-FileCopyrightText: 2024 Tazlin +# SPDX-FileCopyrightText: 2024 ceruleandeep # # SPDX-License-Identifier: AGPL-3.0-or-later -version: '3' services: aihorde: + build: + context: . + dockerfile: Dockerfile image: aihorde:latest container_name: aihorde ports: - "7001:7001" # The port number written in front of the colon (:) is the port number to be exposed to the outside, so if you change it, you can access it with localhost:{changePort}. environment: - - PROFILE=docker # If you write a profile, the .env_{PROFILE} file is read. + # Flask obtains its environment variables from the .env file. + # If you set a profile, the .env_{PROFILE} file is read instead. + - PROFILE=docker volumes: - - .env_docker:/app/.env_docker # You can replace the local pre-built .env file with the container's file. + # .env_{PROFILE} is copied into the image when it is built. 
+ # So that you can change the environment variables without rebuilding the image, mount the .env file. + - .env_docker:/app/.env_docker + # Likewise, you can mount the horde directory to change the source code without rebuilding the image. + - ./horde:/app/horde networks: - aihorde_network depends_on: @@ -27,9 +36,10 @@ services: container_name: postgres restart: always environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: admin - POSTGRES_DB: postgres + POSTGRES_PASSWORD: changeme + volumes: + # Use a named volume to persist the data even if the container is deleted. + - postgres_data:/var/lib/postgresql/data/ ports: - "5432:5432" networks: @@ -47,3 +57,6 @@ services: networks: aihorde_network: driver: bridge + +volumes: + postgres_data: diff --git a/horde/apis/apiv2.py b/horde/apis/apiv2.py index d566c616..dabb9b69 100644 --- a/horde/apis/apiv2.py +++ b/horde/apis/apiv2.py @@ -6,12 +6,13 @@ from flask_restx import Api from horde.apis.v2 import api as v2 +from horde.consts import HORDE_API_VERSION from horde.vars import horde_contact_email, horde_title blueprint = Blueprint("apiv2", __name__, url_prefix="/api") api = Api( blueprint, - version="2.0", + version=str(HORDE_API_VERSION), title=f"{horde_title}", description=f"The API documentation for the {horde_title}", contact_email=horde_contact_email, diff --git a/horde/apis/models/kobold_v2.py b/horde/apis/models/kobold_v2.py index 52f9df56..4837cadd 100644 --- a/horde/apis/models/kobold_v2.py +++ b/horde/apis/models/kobold_v2.py @@ -107,10 +107,9 @@ def __init__(self, api): "generations": fields.List(fields.Nested(self.response_model_generation_result)), }, ) - self.root_model_generation_payload_kobold = api.model( - "ModelPayloadRootKobold", + self.root_model_generation_payload_style_kobold = api.model( + "ModelPayloadStyleKobold", { - "n": fields.Integer(example=1, min=1, max=20), "frmtadsnsp": fields.Boolean( example=False, description=( @@ -137,18 +136,6 @@ def __init__(self, api): "If the output is 
less than one sentence long, does nothing." ), ), - "max_context_length": fields.Integer( - min=80, - default=1024, - max=32000, - description="Maximum number of tokens to send to the model.", - ), - "max_length": fields.Integer( - min=16, - max=1024, - default=80, - description="Number of tokens to generate.", - ), "rep_pen": fields.Float(description="Base repetition penalty value.", min=1, max=3), "rep_pen_range": fields.Integer(description="Repetition penalty range.", min=0, max=4096), "rep_pen_slope": fields.Float(description="Repetition penalty slope.", min=0, max=10), @@ -159,12 +146,6 @@ def __init__(self, api): "including the newline." ), ), - # "soft_prompt": fields.String( - # description=( - # "Soft prompt to use when generating. If set to the empty string or any other string containing " - # "no non-whitespace characters, uses no soft prompt." - # ) - # ), "temperature": fields.Float(description="Temperature value.", min=0, max=5.0), "tfs": fields.Float(description="Tail free sampling value.", min=0.0, max=1.0), "top_a": fields.Float(description="Top-a sampling value.", min=0.0, max=1.0), @@ -202,6 +183,32 @@ def __init__(self, api): ), }, ) + self.root_model_generation_payload_kobold = api.inherit( + "ModelPayloadRootKobold", + self.root_model_generation_payload_style_kobold, + { + "n": fields.Integer(example=1, min=1, max=20), + "max_context_length": fields.Integer( + min=80, + default=1024, + max=32768, + description="Maximum number of tokens to send to the model.", + ), + "max_length": fields.Integer( + min=16, + max=1024, + default=80, + description="Number of tokens to generate.", + ), + # "soft_prompt": fields.String( + # description=( + # "Soft prompt to use when generating. If set to the empty string or any other string containing " + # "no non-whitespace characters, uses no soft prompt." 
+ # ) + # ), + }, + ) + # The pop response playload self.response_model_generation_payload = api.inherit( "ModelPayloadKobold", self.root_model_generation_payload_kobold, @@ -209,6 +216,7 @@ def __init__(self, api): "prompt": fields.String(description="The prompt which will be sent to KoboldAI to generate the text."), }, ) + # The generation input self.input_model_generation_payload = api.inherit( "ModelGenerationInputKobold", self.root_model_generation_payload_kobold, @@ -242,7 +250,7 @@ def __init__(self, api): }, ) self.response_model_job_pop = api.model( - "GenerationPayload", + "GenerationPayloadKobold", { "payload": fields.Nested(self.response_model_generation_payload, skip_none=True), "id": fields.String(description="The UUID for this text generation."), @@ -252,6 +260,7 @@ def __init__(self, api): example="00000000-0000-0000-0000-000000000000", ), ), + "ttl": fields.Integer(description="The amount of seconds before this job is considered stale and aborted."), "extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)), "skipped": fields.Nested(self.response_model_generations_skipped, skip_none=True), "softprompt": fields.String(description="The soft prompt requested for this generation."), @@ -294,6 +303,14 @@ def __init__(self, api): "When False, Evaluating workers will also be used which can increase speed but adds more risk!" ), ), + "validated_backends": fields.Boolean( + default=True, + description=( + f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. " + "When False, non-validated backends will also be used which can increase speed but " + "you may end up with unexpected results." + ), + ), "slow_workers": fields.Boolean( default=True, description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.", @@ -340,6 +357,19 @@ def __init__(self, api): "The request will include the details of the job as well as the request ID." 
), ), + "style": fields.String( + required=False, + max_length=1024, + min_length=3, + example="00000000-0000-0000-0000-000000000000", + description=("A horde style ID or name to use for this generation"), + ), + "extra_slow_workers": fields.Boolean( + default=False, + description=( + "When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot." + ), + ), }, ) self.response_model_contrib_details = api.inherit( @@ -427,3 +457,121 @@ def __init__(self, api): "gen_metadata": fields.List(fields.Nested(self.model_job_metadata)), }, ) + + # Styles + self.input_model_style_params = api.inherit( + "ModelStyleInputParamsKobold", + self.root_model_generation_payload_style_kobold, + {}, + ) + self.input_model_style = api.model( + "ModelStyleInputKobold", + { + "name": fields.String( + required=True, + example="My Awesome Text Style", + description="The name for the style. Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "info": fields.String( + required=False, + description="Some information about this style.", + example="Dark, brooding vibes", + min_length=10, + max_length=1000, + ), + "prompt": fields.String( + required=False, + description=( + "The prompt template which will be sent to Stable Diffusion to generate an image. " + "The user's prompt will be injected into this." + " This argument MUST include a '{p}' which specifies the part where the user's prompt will be injected." + ), + default="{p}", + min_length=3, + ), + "params": fields.Nested(self.input_model_style_params, skip_none=True), + "public": fields.Boolean( + default=True, + description=( + "When true this style will be listed among all styles publicly." + "When false, information about this style can only be seen by people who know its ID or name." 
+ ), + ), + "nsfw": fields.Boolean( + default=False, + description=("When true, it signified this style is expected to generare NSFW images primarily."), + ), + "tags": fields.List( + fields.String( + description="Tags describing this style. Used for filtering and discovery.", + min_length=1, + max_length=25, + example="dark", + ), + ), + "models": fields.List(fields.String(description="The models to use with this style.", min_length=1, example="llama3")), + }, + ) + self.patch_model_style = api.model( + "ModelStylePatchKobold", + { + "name": fields.String( + required=False, + example="My Awesome Text Style", + description="The name for the style. Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "info": fields.String( + required=False, + example="Dark, brooding vibes", + description="Extra information about this style.", + min_length=1, + max_length=1000, + ), + "prompt": fields.String( + required=False, + description=( + "The prompt template which will be sent to Stable Diffusion to generate an image. " + "The user's prompt will be injected into this." + " This argument MUST include a '{p}' which specifies the part where the user's prompt will be injected." + ), + min_length=7, + ), + "params": fields.Nested(self.input_model_style_params, skip_none=True), + "public": fields.Boolean( + default=True, + description=( + "When true this style will be listed among all styles publicly." + "When false, information about this style can only be seen by people who know its ID or name." + ), + ), + "nsfw": fields.Boolean( + default=False, + description=("When true, it signified this style is expected to generare NSFW images primarily."), + ), + "tags": fields.List( + fields.String( + description="Tags describing this style. 
Used for filtering and discovery.", + min_length=1, + max_length=25, + example="dark", + ), + ), + "models": fields.List(fields.String(description="The models to use with this style.", min_length=1, example="llama3")), + }, + ) + self.response_model_style = api.inherit( + "StyleKobold", + self.input_model_style, + { + "id": fields.String( + description="The UUID of the style. Use this to use the style or retrieve its information in the future.", + example="00000000-0000-0000-0000-000000000000", + ), + "use_count": fields.Integer(description="The amount of times this style has been used in generations."), + "creator": fields.String(description="The alias of the user to whom this style belongs to.", example="db0#1"), + }, + ) diff --git a/horde/apis/models/stable_v2.py b/horde/apis/models/stable_v2.py index 563af6ca..d6f53ac7 100644 --- a/horde/apis/models/stable_v2.py +++ b/horde/apis/models/stable_v2.py @@ -152,6 +152,14 @@ def __init__(self): help="If True, this worker will pick up requests requesting LoRas.", location="json", ) + self.job_pop_parser.add_argument( + "limit_max_steps", + type=bool, + required=False, + default=False, + help="If True, This worker will not pick up jobs with more steps than the average allowed for that model.", + location="json", + ) self.job_submit_parser.add_argument( "seed", type=int, @@ -185,6 +193,7 @@ def __init__(self, api): "source_mask", "extra_source_images", "batch_index", + "information", ], description="The relevance of the metadata field", ), @@ -313,18 +322,13 @@ def __init__(self, api): }, ) self.input_model_special_payload = api.model("ModelSpecialPayloadStable", {"*": fields.Wildcard(fields.Raw)}) - self.root_model_generation_payload_stable = api.model( - "ModelPayloadRootStable", + self.root_model_generation_payload_style_stable = api.model( + "ModelPayloadStyleStable", { "sampler_name": fields.String(required=False, default="k_euler_a", enum=list(KNOWN_SAMPLERS)), "cfg_scale": fields.Float(required=False, 
default=7.5, min=0, max=100), "denoising_strength": fields.Float(required=False, example=0.75, min=0.01, max=1.0), "hires_fix_denoising_strength": fields.Float(required=False, example=0.75, min=0.01, max=1.0), - "seed": fields.String( - required=False, - example="The little seed that could", - description="The seed to use to generate this request. You can pass text as well as numbers.", - ), "height": fields.Integer( required=False, default=512, @@ -341,13 +345,6 @@ def __init__(self, api): max=3072, multiple=64, ), - "seed_variation": fields.Integer( - required=False, - example=1, - min=1, - max=1000, - description="If passed with multiple n, the provided seed will be incremented every time by this value.", - ), "post_processing": fields.List( fields.String( description="The list of post-processors to apply to the image, in the order to be applied.", @@ -377,6 +374,38 @@ def __init__(self, api): max=12, description="The number of CLIP language processor layers to skip.", ), + "facefixer_strength": fields.Float(required=False, example=0.75, min=0, max=1.0), + "loras": fields.List(fields.Nested(self.input_model_loras, skip_none=True)), + "tis": fields.List(fields.Nested(self.input_model_tis, skip_none=True)), + "special": fields.Nested(self.input_model_special_payload, skip_none=True), + "workflow": fields.String( + required=False, + default=None, + enum=list(KNOWN_WORKFLOWS), + description="Explicitly specify the horde-engine workflow to use.", + ), + "transparent": fields.Boolean( + default=False, + description="Set to True to generate the image using Layer Diffuse, creating an image with a transparent background.", + ), + }, + ) + self.root_model_generation_payload_stable = api.inherit( + "ModelPayloadRootStable", + self.root_model_generation_payload_style_stable, + { + "seed": fields.String( + required=False, + example="The little seed that could", + description="The seed to use to generate this request. 
You can pass text as well as numbers.", + ), + "seed_variation": fields.Integer( + required=False, + example=1, + min=1, + max=1000, + description="If passed with multiple n, the provided seed will be incremented every time by this value.", + ), "control_type": fields.String( required=False, enum=[ @@ -399,23 +428,10 @@ def __init__(self, api): default=False, description="Set to True if you want the ControlNet map returned instead of a generated image.", ), - "facefixer_strength": fields.Float(required=False, example=0.75, min=0, max=1.0), - "loras": fields.List(fields.Nested(self.input_model_loras, skip_none=True)), - "tis": fields.List(fields.Nested(self.input_model_tis, skip_none=True)), - "special": fields.Nested(self.input_model_special_payload, skip_none=True), "extra_texts": fields.List(fields.Nested(self.model_extra_texts)), - "workflow": fields.String( - required=False, - default=None, - enum=list(KNOWN_WORKFLOWS), - description="Explicitly specify the horde-engine workflow to use.", - ), - "transparent": fields.Boolean( - default=False, - description="Set to True to generate the image using Layer Diffuse, creating an image with a transparent background.", - ), }, ) + # The response for the pop self.response_model_generation_payload = api.inherit( "ModelPayloadStable", self.root_model_generation_payload_stable, @@ -430,6 +446,7 @@ def __init__(self, api): ), }, ) + # The input for the generation self.input_model_generation_payload = api.inherit( "ModelGenerationInputStable", self.root_model_generation_payload_stable, @@ -451,6 +468,9 @@ def __init__(self, api): "max_pixels": fields.Integer( description="How many waiting requests were skipped because they demanded a higher size than this worker provides.", ), + "step_count": fields.Integer( + description="How many waiting requests were skipped because they demanded a higher step count that the worker wants.", + ), "unsafe_ip": fields.Integer( description="How many waiting requests were skipped because they 
came from an unsafe IP.", ), @@ -482,6 +502,7 @@ def __init__(self, api): example="00000000-0000-0000-0000-000000000000", ), ), + "ttl": fields.Integer(description="The amount of seconds before this job is considered stale and aborted."), "skipped": fields.Nested(self.response_model_generations_skipped, skip_none=True), "model": fields.String(description="Which of the available models to use for this request."), "source_image": fields.String(description="The Base64-encoded webp to use for img2img."), @@ -544,6 +565,13 @@ def __init__(self, api): default=True, description="If True, this worker will pick up requests requesting LoRas.", ), + "limit_max_steps": fields.Boolean( + default=True, + description=( + "If True, This worker will not pick up jobs with more steps than the average allowed for that model." + " this is for use by workers which might run into issues doing too many steps." + ), + ), }, ) self.input_model_job_submit = api.inherit( @@ -579,10 +607,24 @@ def __init__(self, api): "When False, Evaluating workers will also be used which can increase speed but adds more risk!" ), ), + "validated_backends": fields.Boolean( + default=True, + description=( + f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. " + "When False, non-validated backends will also be used which can increase speed but " + "you may end up with unexpected results." + ), + ), "slow_workers": fields.Boolean( default=True, description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.", ), + "extra_slow_workers": fields.Boolean( + default=False, + description=( + "When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot." 
+ ), + ), "censor_nsfw": fields.Boolean( default=False, description="If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image.", @@ -665,6 +707,13 @@ def __init__(self, api): "The request will include the details of the job as well as the request ID." ), ), + "style": fields.String( + required=False, + max_length=1024, + min_length=3, + example="00000000-0000-0000-0000-000000000000", + description=("A horde style ID or name to use for this generation"), + ), }, ) self.response_model_team_details = api.inherit( @@ -938,3 +987,170 @@ def __init__(self, api): "total": fields.Nested(self.response_model_model_stats), }, ) + + # Styles + self.input_model_style_params = api.inherit( + "ModelStyleInputParamsStable", + self.root_model_generation_payload_style_stable, + { + "steps": fields.Integer(default=30, required=False, min=1, max=500), + }, + ) + self.input_model_style = api.model( + "ModelStyleInputStable", + { + "name": fields.String( + required=True, + example="My Awesome Image Style", + description="The name for the style. Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "info": fields.String( + required=False, + description="Some information about this style.", + example="photorealism excellence.", + min_length=10, + max_length=1000, + ), + "prompt": fields.String( + required=False, + description=( + "The prompt template which will be sent to Stable Diffusion to generate an image. " + "The user's prompt will be injected into this." + " This argument MUST include a '{p}' which specifies the part where the user's prompt will be injected " + "and an '{np}' where the user's negative prompt will be injected (if any)" + ), + default="{p}{np}", + min_length=7, + ), + "params": fields.Nested(self.input_model_style_params, skip_none=True), + "public": fields.Boolean( + default=True, + description=( + "When true this style will be listed among all styles publicly." 
+ "When false, information about this style can only be seen by people who know its ID or name." + ), + ), + "nsfw": fields.Boolean( + default=False, + description=("When true, it signified this style is expected to generare NSFW images primarily."), + ), + "tags": fields.List( + fields.String( + description="Tags describing this style. Used for filtering and discovery.", + min_length=1, + max_length=25, + example="photorealistic", + ), + ), + "models": fields.List( + fields.String(description="The models to use with this style.", min_length=1, example="stable_diffusion"), + ), + }, + ) + self.patch_model_style = api.model( + "ModelStylePatchStable", + { + "name": fields.String( + required=False, + example="My Awesome Image Style", + description="The name for the style. Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "info": fields.String( + required=False, + example="photorealism excellence.", + description="Extra information about this style.", + min_length=1, + max_length=1000, + ), + "prompt": fields.String( + required=False, + description=( + "The prompt template which will be sent to Stable Diffusion to generate an image. " + "The user's prompt will be injected into this." + " This argument MUST include a '{p}' which specifies the part where the user's prompt will be injected " + "and an '{np}' where the user's negative prompt will be injected (if any)" + ), + min_length=7, + ), + "params": fields.Nested(self.input_model_style_params, skip_none=True), + "public": fields.Boolean( + default=True, + description=( + "When true this style will be listed among all styles publicly." + "When false, information about this style can only be seen by people who know its ID or name." + ), + ), + "nsfw": fields.Boolean( + default=False, + description=("When true, it signified this style is expected to generare NSFW images primarily."), + ), + "tags": fields.List( + fields.String( + description="Tags describing this style. 
Used for filtering and discovery.", + min_length=1, + max_length=25, + example="photorealistic", + ), + ), + "models": fields.List( + fields.String(description="The models to use with this style.", min_length=1, example="stable_diffusion"), + ), + }, + ) + self.input_model_style_example_post = api.model( + "InputStyleExamplePost", + { + "url": fields.String( + example="https://lemmy.dbzer0.com/pictrs/image/c9915186-ca30-4f5a-873c-a91287fb4419.webp", + required=True, + description="Any extra information from the horde about this request.", + ), + "primary": fields.Boolean( + required=True, + default=False, + description="When true this image is to be used as the primary example for this style.", + ), + }, + ) + self.input_model_style_example_patch = api.model( + "InputStyleExamplePost", + { + "url": fields.String( + example="https://lemmy.dbzer0.com/pictrs/image/c9915186-ca30-4f5a-873c-a91287fb4419.webp", + required=False, + description="Any extra information from the horde about this request.", + ), + "primary": fields.Boolean( + required=False, + description="When true this image is to be used as the primary example for this style.", + ), + }, + ) + self.response_model_style_example = api.inherit( + "StyleExample", + self.input_model_style_example_post, + { + "id": fields.String( + example="00000000-0000-0000-0000-000000000000", + description="The UUID of this example.", + ), + }, + ) + + self.response_model_style = api.inherit( + "StyleStable", + self.input_model_style, + { + "id": fields.String( + description="The UUID of the style. 
Use this to use the style or retrieve its information in the future.", + example="00000000-0000-0000-0000-000000000000", + ), + "use_count": fields.Integer(description="The amount of times this style has been used in generations."), + "creator": fields.String(description="The alias of the user to whom this style belongs to.", example="db0#1"), + "examples": fields.List(fields.Nested(self.response_model_style_example, skip_none=True)), + }, + ) diff --git a/horde/apis/models/v2.py b/horde/apis/models/v2.py index 72526079..2ab2257b 100644 --- a/horde/apis/models/v2.py +++ b/horde/apis/models/v2.py @@ -11,6 +11,29 @@ class Parsers: def __init__(self): + # A Basic parser which only expects a Client-Agent + self.basic_parser = reqparse.RequestParser() + self.basic_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + + # A Basic parser which only expects a Client-Agent and an API Key + self.apikey_parser = reqparse.RequestParser() + self.apikey_parser.add_argument("apikey", type=str, required=True, help="A mod API key.", location="headers") + self.apikey_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + self.generate_parser = reqparse.RequestParser() self.generate_parser.add_argument( "apikey", @@ -57,6 +80,16 @@ def __init__(self): "When False, Evaluating workers will also be used.", location="json", ) + self.generate_parser.add_argument( + "validated_backends", + type=bool, + required=False, + default=False, + help=f"When true, only inference backends that are validated by the {horde_title} devs will serve this request. 
" + "When False, non-validated backends will also be used which can increase speed but " + "you may end up with unexpected results.", + location="json", + ) self.generate_parser.add_argument( "workers", type=list, @@ -91,6 +124,14 @@ def __init__(self): help="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.", location="json", ) + self.generate_parser.add_argument( + "extra_slow_workers", + type=bool, + default=False, + required=False, + help="When True, allows very slower workers to pick up this request. Use this when you don't mind waiting a lot.", + location="json", + ) self.generate_parser.add_argument( "dry_run", type=bool, @@ -124,6 +165,7 @@ def __init__(self): location="json", ) self.generate_parser.add_argument("webhook", type=str, required=False, location="json") + self.generate_parser.add_argument("style", type=str, required=False, location="json") # The parser for RequestPop self.job_pop_parser = reqparse.RequestParser() @@ -194,6 +236,13 @@ def __init__(self): help="How many jobvs to pop at the same time", location="json", ) + self.job_pop_parser.add_argument( + "extra_slow_worker", + type=bool, + default=False, + required=False, + location="json", + ) self.job_submit_parser = reqparse.RequestParser() self.job_submit_parser.add_argument( @@ -233,6 +282,153 @@ def __init__(self): location="json", ) + # Style Parsers + self.style_parser = reqparse.RequestParser() + self.style_parser.add_argument( + "apikey", + type=str, + required=True, + help="The API Key corresponding to a registered user.", + location="headers", + ) + self.style_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + self.style_parser.add_argument( + "name", + type=str, + required=True, + help="The name of the style.", + location="json", + ) + self.style_parser.add_argument( + "info", + type=str, + required=False, + 
help="Extra information about this style.", + location="json", + ) + self.style_parser.add_argument( + "prompt", + type=str, + required=False, + default="{p}{np}", + help="The prompt to generate from.", + location="json", + ) + self.style_parser.add_argument( + "params", + type=dict, + required=False, + help="Extra generate params to send to the worker.", + location="json", + ) + self.style_parser.add_argument( + "public", + type=bool, + default=True, + required=False, + location="json", + ) + self.style_parser.add_argument( + "nsfw", + type=bool, + default=False, + required=False, + location="json", + ) + self.style_parser.add_argument( + "tags", + type=list, + required=False, + help="Tags describing this style. Can be used for style discovery.", + location="json", + ) + self.style_parser.add_argument( + "models", + type=list, + required=False, + help="Tags describing this style. Can be used for style discovery.", + location="json", + ) + self.style_parser_patch = reqparse.RequestParser() + self.style_parser_patch.add_argument( + "apikey", + type=str, + required=True, + help="The API Key corresponding to a registered user.", + location="headers", + ) + self.style_parser_patch.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + self.style_parser_patch.add_argument( + "name", + type=str, + required=False, + help="The name of the style.", + location="json", + ) + self.style_parser_patch.add_argument( + "info", + type=str, + required=False, + help="Extra information about this style.", + location="json", + ) + self.style_parser_patch.add_argument( + "prompt", + type=str, + required=False, + help="The prompt to generate from.", + location="json", + ) + self.style_parser_patch.add_argument( + "params", + type=dict, + required=False, + help="Extra generate params to send to the worker.", + location="json", + ) + self.style_parser_patch.add_argument( + "public", + 
type=bool, + default=True, + required=False, + location="json", + ) + self.style_parser_patch.add_argument( + "nsfw", + type=bool, + default=False, + required=False, + location="json", + ) + self.style_parser_patch.add_argument( + "tags", + type=list, + required=False, + help="Tags describing this style. Can be used for style discovery.", + location="json", + ) + self.style_parser_patch.add_argument( + "models", + type=list, + required=False, + help="Tags describing this style. Can be used for style discovery.", + location="json", + ) + class Models: def __init__(self, api): @@ -413,6 +609,7 @@ def __init__(self, api): { "payload": fields.Nested(self.response_model_generation_payload, skip_none=True), "id": fields.String(description="The UUID for this generation."), + "ttl": fields.Integer(description="The amount of seconds before this job is considered stale and aborted."), "skipped": fields.Nested(self.response_model_generations_skipped, skip_none=True), }, ) @@ -527,6 +724,13 @@ def __init__(self, api): min=1, max=20, ), + "extra_slow_worker": fields.Boolean( + default=True, + description=( + "If True, marks the worker as very slow. You should only use this if your mps/s is lower than 0.1." + "Extra slow workers are excluded from normal requests but users can opt in to use them." 
+ ), + ), }, ) self.response_model_worker_details = api.inherit( @@ -728,6 +932,10 @@ def __init__(self, api): default=0, description="The amount of Kudos this user has been awarded from things like rating images.", ), + "styled": fields.Float( + default=0, + description="The amount of Kudos this user has been awarded for styling other people's requests.", + ), }, ) @@ -856,11 +1064,11 @@ def __init__(self, api): "UserAmountRecords", { "image": fields.Integer( - description="How many images this user has generated or requested.", + description="How many images this user has generated, requested or styled.", default=0, ), "text": fields.Integer( - description="How many texts this user has generated or requested.", + description="How many texts this user has generated, requested or styled.", default=0, ), "interrogation": fields.Integer( @@ -877,6 +1085,7 @@ def __init__(self, api): "contribution": fields.Nested(self.response_model_user_thing_records), "fulfillment": fields.Nested(self.response_model_user_amount_records), "request": fields.Nested(self.response_model_user_amount_records), + "style": fields.Nested(self.response_model_user_amount_records), }, ) @@ -904,6 +1113,31 @@ def __init__(self, api): }, ) + self.response_model_styles_short = api.model( + "ResponseModelStylesShort", + { + "name": fields.String( + description="The unique name for this style", + example="db0#1::style::my awesome style", + ), + "id": fields.String( + description="The ID of this style", + example="00000000-0000-0000-0000-000000000000", + ), + }, + ) + + self.response_model_styles_user = api.inherit( + "ResponseModelStylesUser", + self.response_model_styles_short, + { + "type": fields.String( + description="The style type, image or text", + enum=["image", "text"], + ), + }, + ) + self.response_model_user_details = api.model( "UserDetails", { @@ -941,6 +1175,7 @@ def __init__(self, api): example="00000000-0000-0000-0000-000000000000", ), ), + "styles": 
fields.List(fields.Nested(self.response_model_styles_user)), "sharedkey_ids": fields.List( fields.String( description="(Privileged) The list of shared key IDs created by this user.", @@ -1558,6 +1793,85 @@ def __init__(self, api): description="The recommended type of worker.", enum=["image", "text"], required=True, + ),} + ) + # Styles + self.response_model_styles_post = api.model( + "StyleModify", + { + "id": fields.String( + example="00000000-0000-0000-0000-000000000000", + description="The UUID of the style. Use this to use this style of retrieve its information in the future.", + ), + "message": fields.String( + default=None, + description="Any extra information from the horde about this request.", + ), + "warnings": fields.List(fields.Nested(self.response_model_warning)), + }, + ) + + # Collections + + self.input_model_collection = api.model( + "InputModelCollection", + { + "name": fields.String( + required=False, + example="My Awesome Collection", + description="The name for the collection. Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "info": fields.String( + required=False, + example="Collection of optimistic styles", + description="Extra information about this collection.", + min_length=1, + max_length=1000, + ), + "public": fields.Boolean( + default=True, + description=( + "When true this collection will be listed among all collections publicly." + "When false, information about this collection can only be seen by people who know its ID or name." + ), + ), + "styles": fields.List(fields.String(description="The styles to use in this collection.", min_length=1)), + }, + ) + + self.response_model_collection = api.model( + "ResponseModelCollection", + { + "id": fields.String( + description="The UUID of the collection. Use this to use this collection of retrieve its information in the future.", + ), + "name": fields.String( + required=False, + description="The name for the collection. 
Case-sensitive and unique per user.", + min_length=1, + max_length=100, + ), + "type": fields.String( + required=False, + description="The kind of styles stored in this collection.", + enum=["image", "text"], + ), + "info": fields.String( + required=False, + description="Extra information about this collection.", + min_length=1, + max_length=1000, + ), + "public": fields.Boolean( + default=True, + description=( + "When true this collection will be listed among all collection publicly." + "When false, information about this collection can only be seen by people who know its ID or name." + ), ), + "styles": fields.List(fields.Nested(self.response_model_styles_short)), + "use_count": fields.Integer(description="The amount of times this collection has been used in generations."), }, ) diff --git a/horde/apis/v2/__init__.py b/horde/apis/v2/__init__.py index 4704a704..db53e60c 100644 --- a/horde/apis/v2/__init__.py +++ b/horde/apis/v2/__init__.py @@ -13,10 +13,21 @@ api.add_resource(stable.Aesthetics, "/generate/rate/") api.add_resource(stable.ImageJobPop, "/generate/pop") api.add_resource(stable.ImageJobSubmit, "/generate/submit") +api.add_resource(stable.ImageStyle, "/styles/image") +api.add_resource(stable.SingleImageStyle, "/styles/image/") +api.add_resource(stable.SingleImageStyleByName, "/styles/image_by_name/") +api.add_resource(stable.ImageStyleExample, "/styles/image//example") +api.add_resource(stable.SingleImageStyleExample, "/styles/image//example/") api.add_resource(kobold.TextAsyncGenerate, "/generate/text/async") api.add_resource(kobold.TextAsyncStatus, "/generate/text/status/") api.add_resource(kobold.TextJobPop, "/generate/text/pop") api.add_resource(kobold.TextJobSubmit, "/generate/text/submit") +api.add_resource(kobold.TextStyle, "/styles/text") +api.add_resource(kobold.SingleTextStyle, "/styles/text/") +api.add_resource(kobold.SingleImageStyleByName, "/styles/text_by_name/") +api.add_resource(base.Collection, "/collections") 
+api.add_resource(base.SingleCollection, "/collections/") +api.add_resource(base.SingleCollectionByName, "/collection_by_name/") api.add_resource(base.Users, "/users") api.add_resource(base.UserSingle, "/users/") api.add_resource(base.FindUser, "/find_user") @@ -24,6 +35,7 @@ api.add_resource(base.SharedKeySingle, "/sharedkeys/") api.add_resource(base.Workers, "/workers") api.add_resource(base.WorkerSingle, "/workers/") +api.add_resource(base.WorkerSingleName, "/workers/name/") api.add_resource(base.TransferKudos, "/kudos/transfer") api.add_resource(base.AwardKudos, "/kudos/award") api.add_resource(base.HordeModes, "/status/modes") diff --git a/horde/apis/v2/base.py b/horde/apis/v2/base.py index d49c2dd3..dd4a7be0 100644 --- a/horde/apis/v2/base.py +++ b/horde/apis/v2/base.py @@ -18,13 +18,13 @@ import horde.apis.limiter_api as lim import horde.classes.base.stats as stats from horde import exceptions as e -from horde import horde_redis as hr from horde.apis.models.v2 import Models, Parsers from horde.argparser import args from horde.classes.base import settings from horde.classes.base.detection import Filter from horde.classes.base.news import News -from horde.classes.base.team import Team, get_all_teams +from horde.classes.base.style import StyleCollection +from horde.classes.base.team import Team, find_team_by_id, get_all_teams from horde.classes.base.user import User, UserSharedKey from horde.classes.base.waiting_prompt import WaitingPrompt from horde.classes.base.worker import Worker @@ -33,6 +33,7 @@ from horde.database import functions as database from horde.detection import prompt_checker from horde.flask import HORDE, cache, db +from horde.horde_redis import horde_redis as hr from horde.image import ensure_source_image_uploaded from horde.limiter import limiter from horde.logger import logger @@ -40,7 +41,7 @@ from horde.patreon import patrons from horde.r2 import upload_prompt from horde.suspicions import Suspicions -from horde.utils import hash_api_key, 
hash_dictionary, is_profane, sanitize_string +from horde.utils import ensure_clean, hash_api_key, hash_dictionary, is_profane, sanitize_string from horde.vars import horde_contact_email, horde_title, horde_url # Not used yet @@ -122,6 +123,7 @@ def post(self): # It causes them to be a shared object from the parsers class self.params = {} self.warnings = set() + self.style_kudos = False if self.args.params: self.params = self.args.params self.models = [] @@ -256,40 +258,40 @@ def validate(self): if ip_timeout: raise e.TimeoutIP(self.user_ip, ip_timeout) # logger.warning(datetime.utcnow()) - prompt_suspicion, _ = prompt_checker(self.args.prompt) + prompt_suspicion, _ = prompt_checker(self.prompt) # logger.warning(datetime.utcnow()) prompt_replaced = False if prompt_suspicion >= 2 and self.gentype != "text": # if replacement filter mode is enabled AND prompt is short enough, do that instead if self.args.replacement_filter or self.user.education: - if not prompt_checker.check_prompt_replacement_length(self.args.prompt): + if not prompt_checker.check_prompt_replacement_length(self.prompt): raise e.BadRequest("Prompt has to be below 7000 chars when replacement filter is on") - self.args.prompt = prompt_checker.apply_replacement_filter(self.args.prompt) + self.prompt = prompt_checker.apply_replacement_filter(self.prompt) # If it returns None, it means it replaced everything with an empty string - if self.args.prompt is not None: + if self.prompt is not None: prompt_replaced = True if not prompt_replaced: # Moderators do not get ip blocked to allow for experiments if not self.user.moderator: prompt_dict = { - "prompt": self.args.prompt, + "prompt": self.prompt, "user": self.username, "type": "regex", } upload_prompt(prompt_dict) self.user.report_suspicion(1, Suspicions.CORRUPT_PROMPT) CounterMeasures.report_suspicion(self.user_ip) - raise e.CorruptPrompt(self.username, self.user_ip, self.args.prompt) - if_nsfw_model = prompt_checker.check_nsfw_model_block(self.args.prompt, 
self.models) + raise e.CorruptPrompt(self.username, self.user_ip, self.prompt) + if_nsfw_model = prompt_checker.check_nsfw_model_block(self.prompt, self.models) if if_nsfw_model or self.user.flagged: # For NSFW models and flagged users, we always do replacements # This is to avoid someone using the NSFW models to figure out the regex since they don't have an IP timeout - self.args.prompt = prompt_checker.nsfw_model_prompt_replace( - self.args.prompt, + self.prompt = prompt_checker.nsfw_model_prompt_replace( + self.prompt, self.models, already_replaced=prompt_replaced, ) - if self.args.prompt is None: + if self.prompt is None: prompt_replaced = False elif prompt_replaced is False: prompt_replaced = True @@ -301,16 +303,16 @@ def validate(self): ) if self.user.flagged and not if_nsfw_model: msg = "To prevent generation of unethical images, we cannot allow this prompt." - raise e.CorruptPrompt(self.username, self.user_ip, self.args.prompt, message=msg) + raise e.CorruptPrompt(self.username, self.user_ip, self.prompt, message=msg) # Disabling as this is handled by the worker-csam-filter now # If I re-enable it, also make it use the prompt replacement # if not prompt_replaced: - # csam_trigger_check = prompt_checker.check_csam_triggers(self.args.prompt) + # csam_trigger_check = prompt_checker.check_csam_triggers(self.prompt) # if csam_trigger_check is not False and self.gentype != "text": # raise e.CorruptPrompt( # self.username, # self.user_ip, - # self.args.prompt, + # self.prompt, # message = (f"The trigger '{csam_trigger_check}' has been detected to generate " # "unethical images on its own and as such has had to be prevented from use. 
" # "Thank you for understanding.") @@ -332,6 +334,7 @@ def initiate_waiting_prompt(self): nsfw=self.args.nsfw, censor_nsfw=self.args.censor_nsfw, trusted_workers=self.args.trusted_workers, + validated_backends=self.args.validated_backends, worker_blacklist=self.args.worker_blacklist, ipaddr=self.user_ip, sharedkey_id=self.args.apikey if self.sharedkey else None, @@ -347,7 +350,11 @@ def activate_waiting_prompt(self): _, _, ) = ensure_source_image_uploaded(eimg["image"], f"{self.wp.id}_exra_src_{iiter}", force_r2=True) - self.wp.activate(self.downgrade_wp_priority, extra_source_images=self.args.extra_source_images) + self.wp.activate( + self.downgrade_wp_priority, + extra_source_images=self.args.extra_source_images, + kudos_adjustment=2 if self.style_kudos is True else 0, + ) class SyncGenerate(GenerateTemplate): @@ -455,7 +462,6 @@ def post(self): # as they're typically countermeasures to raids if skipped_reason != "secret": self.skipped[skipped_reason] = self.skipped.get(skipped_reason, 0) + 1 - # logger.warning(datetime.utcnow()) continue # There is a chance that by the time we finished all the checks, another worker picked up the WP. 
@@ -476,7 +482,7 @@ def post(self): # We report maintenance exception only if we couldn't find any jobs if self.worker.maintenance: raise e.WorkerMaintenance(self.worker.maintenance_msg) - # logger.warning(datetime.utcnow()) + # logger.debug(self.skipped) return {"id": None, "ids": [], "skipped": self.skipped}, 200 def get_sorted_wp(self, priority_user_ids=None): @@ -768,6 +774,15 @@ class Workers(Resource): location="args", ) + get_parser.add_argument( + "name", + required=False, + default=None, + type=str, + help="Find a worker by name (case insensitive).", + location="args", + ) + @api.expect(get_parser) @logger.catch(reraise=True) # @cache.cached(timeout=10, query_string=True) @@ -813,12 +828,83 @@ def get_worker_info_list(self, details_privilege): return workers_ret def parse_worker_by_query(self, workers_list): - if not self.args.type: - return workers_list - return [w for w in workers_list if w["type"] == self.args.type] + if self.args.name: + return [w for w in workers_list if w["name"].lower() == self.args.name.lower()] + if self.args.type: + return [w for w in workers_list if w["type"] == self.args.type] + return workers_list + +class WorkerSingleBase(Resource): -class WorkerSingle(Resource): + def get_worker_by_id(self, worker_id): + cache_exists = True + details_privilege = 0 + if self.args.apikey: + admin = database.find_user_by_api_key(self.args["apikey"]) + if admin and admin.moderator: + details_privilege = 2 + if not hr.horde_r: + cache_exists = False + if details_privilege > 0: + cache_name = f"cached_worker_{worker_id}_privileged" + cached_worker = hr.horde_r_get(cache_name) + else: + cache_name = f"cached_worker_{worker_id}" + cached_worker = hr.horde_r_get(cache_name) + if cache_exists and cached_worker: + worker_details = json.loads(cached_worker) + else: + worker = database.find_worker_by_id(worker_id) + if not worker: + raise e.WorkerNotFound(worker_id) + worker_details = worker.get_details(details_privilege) + 
hr.horde_r_setex_json(cache_name, timedelta(seconds=30), worker_details) + return worker_details + + +class WorkerSingleName(WorkerSingleBase): + get_parser = reqparse.RequestParser() + get_parser.add_argument( + "apikey", + type=str, + required=False, + help="The Moderator or Owner API key.", + location="headers", + ) + get_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version.", + location="headers", + ) + + @api.expect(get_parser) + # @cache.cached(timeout=10) + @api.marshal_with( + models.response_model_worker_details, + code=200, + description="Worker Details", + skip_none=True, + ) + @api.response(401, "Invalid API Key", models.response_model_error) + @api.response(403, "Access Denied", models.response_model_error) + @api.response(404, "Worker Not Found", models.response_model_error) + def get(self, worker_name=""): + """Details of a registered worker + Can retrieve the details of a worker even if inactive + (A worker is considered inactive if it has not checked in for 5 minutes) + """ + self.args = self.get_parser.parse_args() + worker = database.find_worker_id_by_name(worker_name) + if not worker: + raise e.WorkerNotFound(worker_name) + return self.get_worker_by_id(str(worker.id)), 200 + + +class WorkerSingle(WorkerSingleBase): get_parser = reqparse.RequestParser() get_parser.add_argument( "apikey", @@ -852,30 +938,8 @@ def get(self, worker_id=""): Can retrieve the details of a worker even if inactive (A worker is considered inactive if it has not checked in for 5 minutes) """ - cache_exists = True - details_privilege = 0 self.args = self.get_parser.parse_args() - if self.args.apikey: - admin = database.find_user_by_api_key(self.args["apikey"]) - if admin and admin.moderator: - details_privilege = 2 - if not hr.horde_r: - cache_exists = False - if details_privilege > 0: - cache_name = f"cached_worker_{worker_id}_privileged" - cached_worker = hr.horde_r_get(cache_name) - else: 
- cache_name = f"cached_worker_{worker_id}" - cached_worker = hr.horde_r_get(cache_name) - if cache_exists and cached_worker: - worker_details = json.loads(cached_worker) - else: - worker = database.find_worker_by_id(worker_id) - if not worker: - raise e.WorkerNotFound(worker_id) - worker_details = worker.get_details(details_privilege) - hr.horde_r_setex_json(cache_name, timedelta(seconds=30), worker_details) - return worker_details, 200 + return self.get_worker_by_id(worker_id), 200 put_parser = reqparse.RequestParser() put_parser.add_argument( @@ -1014,7 +1078,7 @@ def put(self, worker_id=""): worker.set_team(None) ret_dict["team"] = "None" else: - team = database.find_team_by_id(self.args.team) + team = find_team_by_id(self.args.team) if not team: raise e.TeamNotFound(self.args.team) ret = worker.set_team(team) @@ -1535,7 +1599,7 @@ def get(self): skname = f": {sk.name}" user_details["username"] = user_details["username"] + f" (Shared Key{skname})" if hr.horde_r: - hr.horde_r_setex_json(cache_name, timedelta(seconds=300), user_details) + hr.horde_r_setex_json(cache_name, timedelta(seconds=30), user_details) return (user_details, 200) @@ -1945,7 +2009,7 @@ class TeamSingle(Resource): @api.response(404, "Team Not Found", models.response_model_error) def get(self, team_id=""): """Details of a worker Team""" - team = database.find_team_by_id(team_id) + team = find_team_by_id(team_id) if not team: raise e.TeamNotFound(team_id) details_privilege = 0 @@ -1993,7 +2057,7 @@ def get(self, team_id=""): @api.response(404, "Team Not Found", models.response_model_error) def patch(self, team_id=""): """Update a Team's information""" - team = database.find_team_by_id(team_id) + team = find_team_by_id(team_id) if not team: raise e.TeamNotFound(team_id) self.args = self.patch_parser.parse_args() @@ -2049,7 +2113,7 @@ def delete(self, team_id=""): Only the team's creator or a horde moderator can use this endpoint. This action is unrecoverable! 
""" - team = database.find_team_by_id(team_id) + team = find_team_by_id(team_id) if not team: raise e.TeamNotFound(team_id) self.args = self.delete_parser.parse_args() @@ -3080,6 +3144,163 @@ class AutoWorkerType(Resource): help="A User API key.", location="headers", ) + +## Styles + + +class StyleTemplate(Resource): + gentype = "template" + args = None + + def get(self): + if self.args.sort not in ["popular", "age"]: + raise e.BadRequest("'model_state' needs to be one of ['popular', 'age']") + styles_ret = database.retrieve_available_styles( + style_type=self.gentype, + sort=self.args.sort, + page=self.args.page - 1, + tag=self.args.tag, + model=self.args.model, + ) + styles_ret = [st.get_details() for st in styles_ret] + return styles_ret, 200 + + def post(self): + # I have to extract and store them this way, because if I use the defaults + # It causes them to be a shared object from the parsers class + self.params = {} + self.warnings = set() + if self.args.params: + self.params = self.args.params + # For styles, we just store the models in the params + self.models = [] + if self.args.models: + self.params["models"] = self.args.models.copy() + self.user = None + self.validate() + return + + def validate(self): + pass + + +class SingleStyleTemplateGet(Resource): + gentype = "template" + + def get_existing_style(self): + if self.existing_style.style_type != self.gentype: + raise e.BadRequest( + f"Style was found but was of the wrong type: {self.existing_style.style_type} != {self.gentype}", + "StyleGetMistmatch", + ) + return self.existing_style.get_details() + + def get_through_id(self, style_id): + self.existing_style = database.get_style_by_uuid(style_id, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound(f"{self.gentype} Style", style_id) + return self.get_existing_style() + + +class SingleStyleTemplate(SingleStyleTemplateGet): + + def patch(self, style_id): + self.params = {} + self.warnings = set() + self.args = 
parsers.style_parser.parse_args() + if self.args.params: + self.params = self.args.params + # For styles, we just store the models in the params + self.models = [] + style_modified = False + self.tags = [] + if self.args.tags: + self.tags = self.args.tags.copy() + if len(self.tags) > 10: + raise e.BadRequest("A style can be tagged a maximum of 10 times.") + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Style PATCH") + self.existing_style = database.get_style_by_uuid(style_id, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound("Style", style_id) + if self.existing_style.user_id != self.user.id: + raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}") + if self.args.models: + self.models = self.args.models.copy() + if len(self.models) > 5: + raise e.BadRequest("A style can only use a maximum of 5 models.") + if len(self.models) < 1: + raise e.BadRequest("A style has to specify at least one model.") + else: + self.models = self.existing_style.get_model_names() + self.style_name = None + if self.args.name: + self.style_name = ensure_clean(self.args.name, "style name") + style_modified = True + self.validate() + self.existing_style.name = self.style_name + if self.args.info is not None: + self.existing_style.info = ensure_clean(self.args.info, "style info") + style_modified = True + if self.args.public is not None: + self.existing_style.public = self.args.public + style_modified = True + if self.args.nsfw is not None: + self.existing_style.nsfw = self.args.nsfw + style_modified = True + if self.args.prompt is not None: + self.existing_style.prompt = self.args.prompt + style_modified = True + if self.args.params is not None: + self.existing_style.params = self.args.params + style_modified = True + if len(self.models) > 0: + style_modified = True + if len(self.tags) > 0: + style_modified = True + if not style_modified: + return { + "id": 
self.existing_style.id, + "message": "OK", + }, 200 + db.session.commit() + self.existing_style.set_models(self.models) + self.existing_style.set_tags(self.tags) + return { + "id": self.existing_style.id, + "message": "OK", + "warnings": self.warnings, + }, 200 + + def validate(self): + pass + + def delete(self, style_id): + self.args = parsers.apikey_parser.parse_args() + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Style DELETE") + if self.user.is_anon(): + raise e.Forbidden("Anonymous users cannot delete styles", rc="StylesAnonForbidden") + self.existing_style = database.get_style_by_uuid(style_id, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound("Style", style_id) + if self.existing_style.user_id != self.user.id and not self.user.moderator: + raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}") + if self.existing_style.user_id != self.user.id and self.user.moderator: + logger.info(f"Moderator {self.user.moderator} deleted style {self.existing_style.id}") + self.existing_style.delete() + return ({"message": "OK"}, 200) + + +## Collections + + +class Collection(Resource): + args = None + + get_parser = reqparse.RequestParser() get_parser.add_argument( "Client-Agent", default="unknown:0:unknown", @@ -3122,4 +3343,339 @@ def get(self): if image_workers_count > text_workers_count: return {"recommended_worker_type": "text"}, 200 else: - return {"recommended_worker_type": "image"}, 200 \ No newline at end of file + return {"recommended_worker_type": "image"}, 200 + get_parser.add_argument( + "sort", + required=False, + default="popular", + type=str, + help="How to sort returned styles. 'popular' sorts by usage and 'age' sorts by date added.", + location="args", + ) + get_parser.add_argument( + "page", + required=False, + default=1, + type=int, + help="Which page of results to return. 
Each page has 25 styles.", + location="args", + ) + get_parser.add_argument( + "type", + required=False, + default="all", + type=str, + help="Filter by type. Accepts either 'image', 'text' or 'all'.", + location="args", + ) + + @cache.cached(timeout=30, query_string=True) + @api.expect(get_parser) + @api.marshal_with( + models.response_model_collection, + code=200, + description="Lists collection information", + as_list=True, + ) + def get(self): + """Displays all existing collections. Can filter by type""" + self.args = self.get_parser.parse_args() + if self.args.sort not in ["popular", "age"]: + raise e.BadRequest("'model_state' needs to be one of ['popular', 'age']") + if self.args.type not in ["all", "image", "text"]: + raise e.BadRequest("'type' needs to be one of ['all', 'image', 'text']") + collections = database.retrieve_available_collections( + sort=self.args.sort, + page=self.args.page - 1, + collection_type=self.args.type if self.args.type in ["image", "text"] else None, + ) + collections_ret = [co.get_details() for co in collections] + return collections_ret, 200 + + post_parser = reqparse.RequestParser() + post_parser.add_argument( + "apikey", + type=str, + required=True, + help="The API Key corresponding to a registered user.", + location="headers", + ) + post_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + post_parser.add_argument( + "name", + type=str, + required=True, + location="json", + ) + post_parser.add_argument( + "info", + type=str, + required=False, + location="json", + ) + post_parser.add_argument( + "public", + type=bool, + default=True, + required=False, + location="json", + ) + post_parser.add_argument( + "styles", + type=list, + required=True, + location="json", + ) + + decorators = [ + limiter.limit( + limit_value=lim.get_request_90min_limit_per_ip, + key_func=lim.get_request_path, + ), + 
limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path), + ] + + @api.expect(post_parser, models.input_model_collection, validate=True) + @api.marshal_with( + models.response_model_styles_post, + code=200, + description="Collection Added", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def post(self): + """Creates a new style collection.""" + self.warnings = set() + # For styles, we just store the models in the params + self.styles = [] + styles_type = None + self.args = self.post_parser.parse_args() + if self.args.styles: + if len(self.args.styles) < 1: + raise e.BadRequest("A collection has to include at least 1 style") + else: + raise e.BadRequest("A collection has to include at least 1 style") + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Collection POST") + if self.user.is_anon(): + raise e.Forbidden("Anonymous users cannot create collections", rc="StylesAnonForbidden") + for st in self.args.styles: + existing_style = database.get_style_by_uuid(st, is_collection=False) + if not existing_style: + existing_style = database.get_style_by_name(st, is_collection=False) + if not existing_style: + raise e.BadRequest(f"A style with name '{st}' cannot be found") + if styles_type is None: + styles_type = existing_style.style_type + elif styles_type != existing_style.style_type: + raise e.BadRequest("Cannot mix image and text styles in the same collection") + self.styles.append(existing_style) + self.collection_name = ensure_clean(self.args.name, "collection name") + new_collection = StyleCollection( + user_id=self.user.id, + style_type=styles_type, + info=ensure_clean(self.args.info, "collection info") if self.args.info is not None else "", + name=self.collection_name, + public=self.args.public, + ) + new_collection.create(self.styles) + return { 
+ "id": new_collection.id, + "message": "OK", + "warnings": self.warnings, + }, 200 + + +class SingleCollectionGet(Resource): + + def get_through_id(self, style_id): + self.existing_collection = database.get_style_by_uuid(style_id, is_collection=True) + if not self.existing_collection: + raise e.ThingNotFound("Collection", style_id) + return self.existing_collection.get_details() + + +class SingleCollection(SingleCollectionGet): + args = None + + @cache.cached(timeout=30, query_string=True) + @api.expect(parsers.basic_parser) + @api.marshal_with( + models.response_model_collection, + code=200, + description="Lists collection information", + as_list=False, + ) + def get(self, collection_id): + """Displays information about a single style collection.""" + return super().get_through_id(collection_id) + + patch_parser = reqparse.RequestParser() + patch_parser.add_argument( + "apikey", + type=str, + required=True, + help="The API Key corresponding to a registered user.", + location="headers", + ) + patch_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version", + location="headers", + ) + patch_parser.add_argument( + "name", + type=str, + required=False, + location="json", + ) + patch_parser.add_argument( + "info", + type=str, + required=False, + location="json", + ) + patch_parser.add_argument( + "public", + type=bool, + required=False, + location="json", + ) + patch_parser.add_argument( + "styles", + type=list, + required=False, + location="json", + ) + + decorators = [ + limiter.limit( + limit_value=lim.get_request_90min_limit_per_ip, + key_func=lim.get_request_path, + ), + limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path), + ] + + @api.expect(patch_parser, models.input_model_collection, validate=True) + @api.marshal_with( + models.response_model_styles_post, + code=200, + description="Collection Modified", + skip_none=True, + ) + 
@api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def patch(self, collection_id): + """Modifies an existing style collection.""" + self.warnings = set() + # For styles, we just store the models in the params + self.styles = [] + styles_type = None + self.args = self.patch_parser.parse_args() + if self.args.styles: + if len(self.args.styles) < 1: + raise e.BadRequest("A collection has to include at least 1 style") + for st in self.args.styles: + existing_style = database.get_style_by_uuid(st, is_collection=False) + if not existing_style: + existing_style = database.get_style_by_name(st, is_collection=False) + if not existing_style: + raise e.BadRequest(f"A style with name '{st}' cannot be found") + if styles_type is None: + styles_type = existing_style.style_type + elif styles_type != existing_style.style_type: + raise e.BadRequest("Cannot mix image and text styles in the same collection", "StyleMismatch") + self.styles.append(existing_style) + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Collection PATCH") + self.existing_collection = database.get_style_by_uuid(collection_id, is_collection=True) + if not self.existing_collection: + raise e.ThingNotFound("Collection", collection_id) + if self.existing_collection.user_id != self.user.id: + raise e.Forbidden(f"This Collection is not owned by user {self.user.get_unique_alias()}") + if self.existing_collection.style_type != styles_type: + raise e.BadRequest("Cannot mix image and text styles in the same collection", "StyleMismatch") + collection_modified = False + if self.args.name: + self.existing_collection.name = ensure_clean(self.args.name, "collection name") + collection_modified = True + if self.args.info is not None: + self.existing_collection.info = ensure_clean(self.args.info, "style info") + collection_modified = True + if self.args.public is not 
None: + self.existing_collection.public = self.args.public + collection_modified = True + if len(self.styles) > 0: + self.existing_collection.styles.clear() + for st in self.styles: + self.existing_collection.styles.append(st) + collection_modified = True + if not collection_modified: + return { + "id": self.existing_collection.id, + "message": "OK", + }, 200 + db.session.commit() + return { + "id": self.existing_collection.id, + "message": "OK", + "warnings": self.warnings, + }, 200 + + @api.expect(parsers.apikey_parser) + @api.marshal_with( + models.response_model_simple_response, + code=200, + description="Operation Completed", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def delete(self, collection_id): + """Deletes a style collection.""" + self.args = parsers.apikey_parser.parse_args() + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Collection PATCH") + self.existing_collection = database.get_style_by_uuid(collection_id, is_collection=True) + if not self.existing_collection: + raise e.ThingNotFound("Collection", collection_id) + if self.existing_collection.user_id != self.user.id and not self.user.moderator: + raise e.Forbidden(f"This Collection is not owned by user {self.user.get_unique_alias()}") + if self.existing_collection.user_id != self.user.id and self.user.moderator: + logger.info(f"Moderator {self.user.moderator} deleted collection {self.existing_collection.id}") + self.existing_collection.delete() + return ({"message": "OK"}, 200) + + +class SingleCollectionByName(SingleCollectionGet): + @cache.cached(timeout=30) + @api.expect(parsers.basic_parser) + @api.marshal_with( + models.response_model_collection, + code=200, + description="Lists collection information by name", + as_list=False, + ) + def get(self, collection_name): + """Seeks an style collection by name 
and displays its information.""" + self.existing_collection = database.get_style_by_name(collection_name) + if not self.existing_collection: + raise e.ThingNotFound("Collection", collection_name) + return self.existing_collection.get_details() + + +# TODO: vote and transfer kudos on vote diff --git a/horde/apis/v2/kobold.py b/horde/apis/v2/kobold.py index 7e318c29..9b956b3c 100644 --- a/horde/apis/v2/kobold.py +++ b/horde/apis/v2/kobold.py @@ -2,14 +2,26 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later +import random +from collections import defaultdict + from flask import request from flask_restx import Resource, reqparse import horde.apis.limiter_api as lim from horde import exceptions as e from horde.apis.models.kobold_v2 import TextModels, TextParsers -from horde.apis.v2.base import GenerateTemplate, JobPopTemplate, JobSubmitTemplate, api +from horde.apis.v2.base import ( + GenerateTemplate, + JobPopTemplate, + JobSubmitTemplate, + SingleStyleTemplate, + SingleStyleTemplateGet, + StyleTemplate, + api, +) from horde.classes.base import settings +from horde.classes.base.style import Style, StyleCollection from horde.classes.kobold.genstats import ( get_compiled_textgen_stats_models, get_compiled_textgen_stats_totals, @@ -22,7 +34,8 @@ from horde.limiter import limiter from horde.logger import logger from horde.model_reference import model_reference -from horde.utils import hash_dictionary +from horde.utils import ensure_clean, hash_dictionary +from horde.validation import ParamValidator from horde.vars import horde_title models = TextModels(api) @@ -64,10 +77,11 @@ def post(self): self.args = parsers.generate_parser.parse_args() try: super().post() - except KeyError: + except KeyError as e: logger.error("caught missing Key.") logger.error(self.args) logger.error(self.args.params) + raise e return {"message": "Internal Server Error"}, 500 if self.args.dry_run: ret_dict = {"kudos": round(self.kudos)} @@ -84,11 +98,12 @@ def initiate_waiting_prompt(self): self.wp = 
TextWaitingPrompt( worker_ids=self.workers, models=self.models, - prompt=self.args.prompt, + prompt=self.prompt, user_id=self.user.id, params=self.params, softprompt=self.args.softprompt, trusted_workers=self.args.trusted_workers, + validated_backends=self.args.validated_backends, worker_blacklist=self.args.worker_blacklist, slow_workers=self.args.slow_workers, ipaddr=self.user_ip, @@ -152,25 +167,13 @@ def get_size_too_big_message(self): ) def validate(self): + self.prompt = self.args.prompt + self.apply_style() super().validate() - if self.params.get("max_context_length", 1024) < self.params.get("max_length", 80): - raise e.BadRequest("You cannot request more tokens than your context length.", rc="TokenOverflow") - if "sampler_order" in self.params and len(set(self.params["sampler_order"])) < 7: - raise e.BadRequest( - "When sending a custom sampler order, you need to specify all possible samplers in the order", - rc="MissingFullSamplerOrder", - ) + param_validator = ParamValidator(self.prompt, self.args.models, self.params, self.user) + self.warnings = param_validator.validate_text_params() if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0: raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.") - if "stop_sequence" in self.params: - stop_seqs = set(self.params["stop_sequence"]) - if len(stop_seqs) > 128: - raise e.BadRequest("Too many stop sequences specified (max allowed is 128).", rc="TooManyStopSequences") - total_stop_seq_len = 0 - for seq in stop_seqs: - total_stop_seq_len += len(seq) - if total_stop_seq_len > 2000: - raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).", rc="ExcessiveStopSequence") def get_hashed_params_dict(self): gen_payload = self.params.copy() @@ -181,6 +184,36 @@ def get_hashed_params_dict(self): # logger.debug([params_hash,gen_payload]) return params_hash + def apply_style(self): + if self.args.style is 
None: + return + self.existing_style = database.get_style_by_uuid(self.args.style) + if not self.existing_style: + self.existing_style = database.get_style_by_name(self.args.style) + if not self.existing_style: + raise e.ThingNotFound("Style", self.args.style) + if self.existing_style.style_type != "text": + raise e.BadRequest("Image styles cannot be used on image requests", "StyleMismatch") + if isinstance(self.existing_style, StyleCollection): + colstyles = self.existing_style.styles + random.shuffle(colstyles) + self.existing_style.use_count += 1 + self.existing_style = colstyles[0] + self.models = self.existing_style.get_model_names() + # We need to use defaultdict to avoid getting keyerrors in case the style author added + # Erroneous keys in the string + self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.prompt)) + requested_n = self.params.get("n", 1) + self.params = self.existing_style.params + self.params["n"] = requested_n + self.nsfw = self.existing_style.nsfw + self.existing_style.use_count += 1 + if self.existing_style.user != self.user: + self.existing_style.user.record_style(2, "text") + self.style_kudos = True + db.session.commit() + logger.debug(f"Style '{self.args.style}' applied.") + class TextAsyncStatus(Resource): get_parser = reqparse.RequestParser() @@ -311,7 +344,6 @@ def get_sorted_wp(self, priority_user_ids=None): priority_user_ids=priority_user_ids, page=self.wp_page, ) - return sorted_wps @@ -409,3 +441,235 @@ def post(self, user_id=""): user.set_trusted(self.args.trusted) user.modify_kudos(self.args.kudos_amount, "koboldai") return {"new_kudos": user.kudos}, 200 + + +## Styles +class TextStyle(StyleTemplate): + gentype = "text" + + get_parser = reqparse.RequestParser() + get_parser.add_argument( + "Client-Agent", + default="unknown:0:unknown", + type=str, + required=False, + help="The client name and version.", + location="headers", + ) + get_parser.add_argument( + "sort", + required=False, + default="popular", 
+ type=str, + help="How to sort returned styles. 'popular' sorts by usage and 'age' sorts by date added.", + location="args", + ) + get_parser.add_argument( + "page", + required=False, + default=1, + type=int, + help="Which page of results to return. Each page has 25 styles.", + location="args", + ) + get_parser.add_argument( + "tag", + required=False, + type=str, + help="If included, will only return styles with this tag", + location="args", + ) + get_parser.add_argument( + "model", + required=False, + type=str, + help="If included, will only return styles using this model", + location="args", + ) + + @logger.catch(reraise=True) + @cache.cached(timeout=1, query_string=True) + @api.expect(get_parser) + @api.marshal_with( + models.response_model_style, + code=200, + description="Lists text styles information", + as_list=True, + ) + def get(self): + """Retrieves information about all text styles + Can be filtered based on model or tags + """ + self.args = self.get_parser.parse_args() + return super().get() + + decorators = [ + limiter.limit( + limit_value="20/hour", + key_func=lim.get_request_path, + ), + limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path), + ] + + @api.expect(parsers.style_parser, models.input_model_style, validate=True) + @api.marshal_with( + models.response_model_styles_post, + code=200, + description="Style Added", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def post(self): + """Creates a new text style""" + self.params = {} + self.warnings = set() + self.args = parsers.style_parser.parse_args() + if self.args.params: + self.params = self.args.params + self.models = [] + if self.args.models is not None: + self.models = self.args.models.copy() + if len(self.models) > 5: + raise e.BadRequest("A style can only use a maximum of 5 models.") + if len(self.models) < 1: + raise 
e.BadRequest("A style has to specify at least one model.") + else: + raise e.BadRequest("A style has to specify at least one model.") + self.tags = [] + if self.args.tags is not None: + self.tags = self.args.tags.copy() + if len(self.tags) > 10: + raise e.BadRequest("A style can be tagged a maximum of 10 times.") + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("TextStyle POST") + if not self.user.customizer and not self.user.trusted: + raise e.Forbidden( + "Only customizers and trusted users can create new styles. You can request the customizer role in our channels.", + rc="StylesRequiresCustomizer", + ) + if self.user.is_anon(): + raise e.Forbidden("Anonymous users cannot create styles", rc="StylesAnonForbidden") + self.style_name = ensure_clean(self.args.name, "style name") + self.validate() + new_style = Style( + user_id=self.user.id, + style_type=self.gentype, + info=ensure_clean(self.args.info, "style info") if self.args.info is not None else "", + name=self.style_name, + public=self.args.public, + nsfw=self.args.nsfw, + prompt=self.args.prompt, + params=self.args.params if self.args.params is not None else {}, + ) + new_style.create() + new_style.set_models(self.models) + new_style.set_tags(self.tags) + return { + "id": new_style.id, + "message": "OK", + "warnings": self.warnings, + }, 200 + + def validate(self): + if database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}"): + raise e.BadRequest( + ( + f"Style with name '{self.style_name}' already exists for user '{self.user.get_unique_alias()}'." + " Please use PATCH to modify an existing style." 
+ ), + ) + param_validator = ParamValidator(prompt=self.args.prompt, models=self.models, params=self.params, user=self.user) + self.warnings = param_validator.validate_text_params() + param_validator.check_for_special() + param_validator.validate_text_prompt(self.args.prompt) + + +class SingleTextStyle(SingleStyleTemplate): + gentype = "text" + + @cache.cached(timeout=30) + @api.expect(parsers.basic_parser) + @api.marshal_with( + models.response_model_style, + code=200, + description="Lists text styles information", + as_list=False, + ) + def get(self, style_id): + """Displays information about a single text style.""" + return super().get_through_id(style_id) + + decorators = [ + limiter.limit( + limit_value=lim.get_request_90min_limit_per_ip, + key_func=lim.get_request_path, + ), + limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path), + ] + + @api.expect(parsers.style_parser_patch, models.patch_model_style, validate=True) + @api.marshal_with( + models.response_model_styles_post, + code=200, + description="Style Updated", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def patch(self, style_id): + """Modifies a text style.""" + return super().patch(style_id) + + def validate(self): + if ( + self.style_name is not None + and database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}") + and self.existing_style.name != self.style_name + ): + raise e.BadRequest( + ( + f"Style with name '{self.style_name}' already exists for user '{self.user.get_unique_alias()}'." + " Please use a different name if you want to rename." 
+ ), + ) + prompt = self.args.prompt if self.args.prompt is not None else self.existing_style.prompt + models = self.models if len(self.models) > 0 else self.existing_style.get_model_names() + params = self.args.params if self.args.params is not None else self.existing_style.params + param_validator = ParamValidator(prompt=prompt, models=models, params=params, user=self.user) + self.warnings = param_validator.validate_text_params() + param_validator.check_for_special() + param_validator.validate_text_prompt(prompt) + + @api.expect(parsers.apikey_parser) + @api.marshal_with( + models.response_model_simple_response, + code=200, + description="Style Deleted", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def delete(self, style_id): + """Deletes a text style.""" + return super().delete(style_id) + + +class SingleImageStyleByName(SingleStyleTemplateGet): + gentype = "text" + + @cache.cached(timeout=30) + @api.expect(parsers.basic_parser) + @api.marshal_with( + models.response_model_style, + code=200, + description="Lists text style information by name", + as_list=False, + ) + def get(self, style_name): + """Seeks a text style by name and displays its information.""" + self.existing_style = database.get_style_by_name(style_name, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound("Style", style_name) + return super().get_existing_style() diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 6f0fd02b..62610635 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later +import random +from collections import defaultdict from datetime import datetime import requests @@ -12,8 +14,17 @@ import horde.classes.base.stats as stats from horde import exceptions as e from horde.apis.models.stable_v2 import ImageModels, ImageParsers -from 
horde.apis.v2.base import GenerateTemplate, JobPopTemplate, JobSubmitTemplate, api +from horde.apis.v2.base import ( + GenerateTemplate, + JobPopTemplate, + JobSubmitTemplate, + SingleStyleTemplate, + SingleStyleTemplateGet, + StyleTemplate, + api, +) from horde.classes.base import settings +from horde.classes.base.style import Style, StyleCollection, StyleExample from horde.classes.base.user import User from horde.classes.stable.genstats import ( get_compiled_imagegen_stats_models, @@ -23,7 +34,6 @@ from horde.classes.stable.interrogation_worker import InterrogationWorker from horde.classes.stable.waiting_prompt import ImageWaitingPrompt from horde.classes.stable.worker import ImageWorker -from horde.consts import KNOWN_POST_PROCESSORS, KNOWN_UPSCALERS from horde.countermeasures import CounterMeasures from horde.database import functions as database from horde.enums import WarningMessage @@ -33,7 +43,8 @@ from horde.logger import logger from horde.model_reference import model_reference from horde.patreon import patrons -from horde.utils import does_extra_text_reference_exist, hash_dictionary +from horde.utils import does_extra_text_reference_exist, ensure_clean, hash_dictionary +from horde.validation import ParamValidator from horde.vars import horde_title models = ImageModels(api) @@ -104,7 +115,12 @@ def get_size_too_big_message(self): ) def validate(self): + self.prompt = self.args.prompt + self.apply_style() super().validate() + param_validator = ParamValidator(prompt=self.prompt, models=self.args.models, params=self.params, user=self.user) + self.warnings = param_validator.validate_image_params() + param_validator.check_for_special() # During raids, we prevent VPNs if settings.mode_raid() and not self.user.trusted and not patrons.is_patron(self.user.id): self.safe_ip = CounterMeasures.is_ip_safe(self.user_ip) @@ -116,48 +132,12 @@ def validate(self): raise e.NotTrusted(rc="UntrustedUnsafeIP") if not self.user.special and self.params.get("special"): raise 
e.BadRequest("Only special users can send a special field.", "SpecialFieldNeedsSpecialUser") - for model in self.args.models: - if "horde_special" in model: - if not self.user.special: - raise e.Forbidden("Only special users can request a special model.", "SpecialModelNeedsSpecialUser") - usermodel = model.split("::") - if len(usermodel) == 1: - raise e.BadRequest( - "Special models must always include the username, in the form of 'horde_special::user#id'", - rc="SpecialMissingUsername", - ) - user_alias = usermodel[1] - if self.user.get_unique_alias() != user_alias: - raise e.Forbidden(f"This model can only be requested by {user_alias}", "SpecialForbidden") - if not self.params.get("special"): - raise e.BadRequest("Special models have to include a special payload", rc="SpecialMissingPayload") if not self.args.source_image and self.args.source_mask: raise e.SourceMaskUnnecessary if self.params.get("control_type") in ["normal", "mlsd", "hough"] and any( model_reference.get_model_baseline(model_name).startswith("stable diffusion 2") for model_name in self.args.models ): raise e.UnsupportedModel("No current model available for this particular ControlNet for SD2.x", rc="ControlNetUnsupported") - for model_req_dict in [model_reference.get_model_requirements(m) for m in self.args.models]: - if "clip_skip" in model_req_dict and model_req_dict["clip_skip"] != self.params.get("clip_skip", 1): - self.warnings.add(WarningMessage.ClipSkipMismatch) - if "min_steps" in model_req_dict and model_req_dict["min_steps"] > self.params.get("steps", 30): - self.warnings.add(WarningMessage.StepsTooFew) - if "max_steps" in model_req_dict and model_req_dict["max_steps"] < self.params.get("steps", 30): - self.warnings.add(WarningMessage.StepsTooMany) - if "cfg_scale" in model_req_dict and model_req_dict["cfg_scale"] != self.params.get("cfg_scale", 7.5): - self.warnings.add(WarningMessage.CfgScaleMismatch) - if "min_cfg_scale" in model_req_dict and model_req_dict["min_cfg_scale"] > 
self.params.get("cfg_scale", 7.5): - self.warnings.add(WarningMessage.CfgScaleTooSmall) - if "max_cfg_scale" in model_req_dict and model_req_dict["max_cfg_scale"] < self.params.get("cfg_scale", 7.5): - self.warnings.add(WarningMessage.CfgScaleTooLarge) - if "samplers" in model_req_dict and self.params.get("sampler_name", "k_euler_a") not in model_req_dict["samplers"]: - self.warnings.add(WarningMessage.SamplerMismatch) - # FIXME: Scheduler workaround until we support multiple schedulers - scheduler = "karras" - if not self.params.get("karras", True): - scheduler = "simple" - if "schedulers" in model_req_dict and scheduler not in model_req_dict["schedulers"]: - self.warnings.add(WarningMessage.SchedulerMismatch) if "control_type" in self.params and any(model_name in ["pix2pix"] for model_name in self.args.models): raise e.UnsupportedModel("You cannot use ControlNet with these models.", rc="ControlNetUnsupported") # if self.params.get("image_is_control"): @@ -172,23 +152,10 @@ def validate(self): if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models): if "control_type" in self.params: raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch") - if "loras" in self.params: - if len(self.params["loras"]) > 5: - raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras") - for lora in self.params["loras"]: - if lora.get("is_version") and not lora["name"].isdigit(): - raise e.BadRequest("explicit LoRa version requests have to be a version ID (i.e integer).", rc="BadLoraVersion") - if "tis" in self.params and len(self.params["tis"]) > 20: - raise e.BadRequest("You cannot request more than 20 Textual Inversions per generation.", rc="TooManyTIs") + if any(model_reference.get_model_baseline(model_name).startswith("flux_1") for model_name in self.args.models): + if "control_type" in self.params: + raise e.BadRequest("ControlNet does not 
work with Flux currently.", rc="ControlNetMismatch") if self.params.get("transparent", False) is True: - if any( - model_reference.get_model_baseline(model_name) not in ["stable_diffusion_xl", "stable diffusion 1"] - for model_name in self.args.models - ): - raise e.BadRequest( - "Generating Transparent images is only possible for Stable Diffusion 1.5 and XL models.", - rc="InvalidTransparencyModel", - ) if self.args.extra_source_images and len(self.args.extra_source_images) > 0: raise e.BadRequest( "Generating Transparent images is not supported during img2img workflows.", @@ -212,11 +179,6 @@ def validate(self): if self.params.get("workflow") == "qr_code": # QR-code pipeline cannot do batching currently self.args["disable_batching"] = True - if not all( - model_reference.get_model_baseline(model_name) in ["stable diffusion 1", "stable_diffusion_xl"] - for model_name in self.args.models - ): - raise e.BadRequest("QR Code controlnet only works with SD 1.5 and SDXL models currently", rc="ControlNetMismatch.") if self.params.get("extra_texts") is None or len(self.params.get("extra_texts")) == 0: raise e.BadRequest("This request requires you pass the required extra texts for this workflow.", rc="MissingExtraTexts.") if not does_extra_text_reference_exist(self.params.get("extra_texts"), "qr_code"): @@ -241,26 +203,6 @@ def validate(self): self.params["n"] = 2 # if any(model_name.startswith("stable_diffusion_2") for model_name in self.args.models): # raise e.UnsupportedModel - if len(self.args["prompt"].split()) > 7500: - raise e.InvalidPromptSize(self.username) - if any(model_name in KNOWN_POST_PROCESSORS for model_name in self.args.models): - raise e.UnsupportedModel(rc="UnexpectedModelName") - if self.args.params: - upscaler_count = len([pp for pp in self.args.params.get("post_processing", []) if pp in KNOWN_UPSCALERS]) - if upscaler_count > 1: - raise e.BadRequest("Cannot use more than 1 upscaler at a time.", rc="TooManyUpscalers") - - cfg_scale = 
self.args.params.get("cfg_scale") - if cfg_scale is not None: - try: - rounded_cfg_scale = round(cfg_scale, 2) - if rounded_cfg_scale != cfg_scale: - raise e.BadRequest("cfg_scale must be rounded to 2 decimal places", rc="BadCFGDecimals") - except (TypeError, ValueError): - logger.warning( - f"Invalid cfg_scale: {cfg_scale} for user {self.username} when it should be already validated.", - ) - raise e.BadRequest("cfg_scale must be a valid number", rc="BadCFGNumber") if self.args["Client-Agent"] in ["My-Project:v0.0.1:My-Contact"]: raise e.Forbidden( @@ -286,14 +228,16 @@ def initiate_waiting_prompt(self): self.wp = ImageWaitingPrompt( worker_ids=self.workers, models=self.models, - prompt=self.args.prompt, + prompt=self.prompt, user_id=self.user.id, params=self.params, nsfw=self.args.nsfw, censor_nsfw=self.args.censor_nsfw, trusted_workers=self.args.trusted_workers, + validated_backends=self.args.validated_backends, worker_blacklist=self.args.worker_blacklist, slow_workers=self.args.slow_workers, + extra_slow_workers=self.args.extra_slow_workers, source_processing=self.args.source_processing, ipaddr=self.user_ip, safe_ip=self.safe_ip, @@ -413,8 +357,45 @@ def activate_waiting_prompt(self): source_image=self.source_image, source_mask=self.source_mask, extra_source_images=self.args.extra_source_images, + kudos_adjustment=2 if self.style_kudos is not None else 0, ) + def apply_style(self): + if self.args.style is None: + return + self.existing_style = database.get_style_by_uuid(self.args.style) + if not self.existing_style: + self.existing_style = database.get_style_by_name(self.args.style) + if not self.existing_style: + raise e.ThingNotFound("Style", self.args.style) + if self.existing_style.style_type != "image": + raise e.BadRequest("Text styles cannot be used on image requests", "StyleMismatch") + if isinstance(self.existing_style, StyleCollection): + colstyles = self.existing_style.styles + random.shuffle(colstyles) + self.existing_style = colstyles[0] + 
self.existing_style.use_count += 1 + self.models = self.existing_style.get_model_names() + self.negprompt = "" + if "###" in self.prompt: + self.prompt, self.negprompt = self.prompt.split("###", 1) + if "###" not in self.existing_style.prompt and self.negprompt != "" and "###" not in self.negprompt: + self.negprompt = "###" + self.negprompt + # We need to use defaultdict to avoid getting keyerrors in case the style author added + # Erroneous keys in the string + self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.prompt, np=self.negprompt)) + requested_n = self.params.get("n", 1) + self.params = self.existing_style.params + self.params["n"] = requested_n + self.nsfw = self.existing_style.nsfw + self.existing_style.use_count += 1 + # We don't reward kudos to ourselves + if self.existing_style.user != self.user: + self.existing_style.user.record_style(2, "image") + self.style_kudos = True + db.session.commit() + logger.debug(f"Style '{self.args.style}' applied.") + class ImageAsyncStatus(Resource): get_parser = reqparse.RequestParser() @@ -429,7 +410,6 @@ class ImageAsyncStatus(Resource): decorators = [limiter.limit("10/minute", key_func=lim.get_request_path)] - # If I marshal it here, it overrides the marshalling of the child class unfortunately @api.expect(get_parser) @api.marshal_with( models.response_model_wp_status_full, @@ -593,6 +573,10 @@ def post(self): db_skipped["kudos"] = post_ret["skipped"]["kudos"] if "blacklist" in post_ret.get("skipped", {}): db_skipped["blacklist"] = post_ret["skipped"]["blacklist"] + if "step_count" in post_ret.get("skipped", {}): + db_skipped["step_count"] = post_ret["skipped"]["step_count"] + if "bridge_version" in post_ret.get("skipped", {}): + db_skipped["bridge_version"] = db_skipped.get("bridge_version", 0) + post_ret["skipped"]["bridge_version"] post_ret["skipped"] = db_skipped # logger.debug(post_ret) return post_ret, retcode @@ -615,6 +599,8 @@ def check_in(self): 
## Styles
class ImageStyle(StyleTemplate):
    """REST resource listing and creating image styles."""

    gentype = "image"

    get_parser = reqparse.RequestParser()
    get_parser.add_argument(
        "Client-Agent",
        default="unknown:0:unknown",
        type=str,
        required=False,
        help="The client name and version.",
        location="headers",
    )
    get_parser.add_argument(
        "sort",
        required=False,
        default="popular",
        type=str,
        help="How to sort returned styles. 'popular' sorts by usage and 'age' sorts by date added.",
        location="args",
    )
    get_parser.add_argument(
        "page",
        required=False,
        default=1,
        type=int,
        help="Which page of results to return. Each page has 25 styles.",
        location="args",
    )
    get_parser.add_argument(
        "tag",
        required=False,
        type=str,
        help="If included, will only return styles with this tag",
        location="args",
    )
    get_parser.add_argument(
        "model",
        required=False,
        type=str,
        help="If included, will only return styles using this model",
        location="args",
    )

    @cache.cached(timeout=30, query_string=True)
    @api.expect(get_parser)
    @api.marshal_with(
        models.response_model_style,
        code=200,
        description="Lists image styles information",
        as_list=True,
    )
    def get(self):
        """Retrieves information about all image styles
        Can be filtered based on model or tags
        """
        self.args = self.get_parser.parse_args()
        return super().get()

    decorators = [
        limiter.limit(
            limit_value=lim.get_request_90min_limit_per_ip,
            key_func=lim.get_request_path,
        ),
        limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path),
    ]

    @api.expect(parsers.style_parser, models.input_model_style, validate=True)
    @api.marshal_with(
        models.response_model_styles_post,
        code=200,
        description="Style Added",
        skip_none=True,
    )
    @api.response(400, "Validation Error", models.response_model_validation_errors)
    @api.response(401, "Invalid API Key", models.response_model_error)
    def post(self):
        """Creates a new image style."""
        self.params = {}
        self.warnings = set()
        self.args = parsers.style_parser.parse_args()
        if self.args.params:
            self.params = self.args.params
        # A style must name between 1 and 5 models. `not self.args.models`
        # covers both a missing field and an explicit empty list, which
        # previously required a duplicated length check.
        if not self.args.models:
            raise e.BadRequest("A style has to specify at least one model.")
        self.models = self.args.models.copy()
        if len(self.models) > 5:
            raise e.BadRequest("A style can only use a maximum of 5 models.")
        self.tags = []
        if self.args.tags is not None:
            self.tags = self.args.tags.copy()
            if len(self.tags) > 10:
                raise e.BadRequest("A style can be tagged a maximum of 10 times.")
        self.user = database.find_user_by_api_key(self.args["apikey"])
        if not self.user:
            raise e.InvalidAPIKey("ImageStyle POST")
        if not self.user.customizer and not self.user.trusted:
            raise e.Forbidden(
                "Only customizers and trusted users can create new styles. You can request the customizer role in our channels.",
                rc="StylesRequiresCustomizer",
            )
        if self.user.is_anon():
            raise e.Forbidden("Anonymous users cannot create styles", rc="StylesAnonForbidden")
        self.style_name = ensure_clean(self.args.name, "style name")
        self.validate()
        new_style = Style(
            user_id=self.user.id,
            style_type=self.gentype,
            info=ensure_clean(self.args.info, "style info") if self.args.info is not None else "",
            name=self.style_name,
            public=self.args.public,
            nsfw=self.args.nsfw,
            prompt=self.args.prompt,
            params=self.args.params if self.args.params is not None else {},
        )
        new_style.create()
        new_style.set_models(self.models)
        new_style.set_tags(self.tags)
        return {
            "id": new_style.id,
            "message": "OK",
            "warnings": self.warnings,
        }, 200

    def validate(self):
        """Reject duplicate style names for this user, then validate prompt and params.

        Populates self.warnings with any non-fatal issues from the validator.
        """
        if database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}"):
            raise e.BadRequest(
                (
                    f"Style with name '{self.style_name}' already exists for user '{self.user.get_unique_alias()}'."
                    " Please use PATCH to modify an existing style."
                ),
            )
        param_validator = ParamValidator(prompt=self.args.prompt, models=self.models, params=self.params, user=self.user)
        self.warnings = param_validator.validate_image_params()
        param_validator.check_for_special()
        param_validator.validate_image_prompt(self.args.prompt)
class ImageStyleExample(Resource):
    """REST resource for attaching example images to an image style."""

    post_parser = reqparse.RequestParser()
    post_parser.add_argument("apikey", type=str, required=True, help="A User API key", location="headers")
    post_parser.add_argument(
        "Client-Agent",
        default="unknown:0:unknown",
        type=str,
        required=False,
        help="The client name and version.",
        location="headers",
    )
    post_parser.add_argument(
        "url",
        required=True,
        # NOTE(review): default "popular" looks copy-pasted from the 'sort'
        # argument of the styles list parser. Harmless since required=True,
        # but confirm and drop.
        default="popular",
        type=str,
        help="The url where this image is hosted",
        location="json",
    )
    post_parser.add_argument(
        "primary",
        required=True,
        default=1,
        type=bool,
        help="Whether this image is meant to be the primary example of this style",
        location="json",
    )

    decorators = [
        limiter.limit(
            limit_value=lim.get_request_90min_limit_per_ip,
            key_func=lim.get_request_path,
        ),
        limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path),
    ]

    @api.expect(post_parser, models.input_model_style_example_post, validate=True)
    @api.marshal_with(
        models.response_model_styles_post,
        code=200,
        description="Style Example Added",
        skip_none=True,
    )
    @api.response(400, "Validation Error", models.response_model_validation_errors)
    @api.response(401, "Invalid API Key", models.response_model_error)
    def post(self, style_id):
        """Creates an image style example.

        Checks ownership (or moderator privilege), enforces the 4-example cap,
        keeps at most one primary example, then stores the new example.
        """
        self.args = self.post_parser.parse_args()
        if not self.args.url.startswith("https://"):
            raise e.BadRequest("The url has to start with 'https://'")
        self.user = database.find_user_by_api_key(self.args["apikey"])
        if not self.user:
            raise e.InvalidAPIKey("Style Example POST")
        self.existing_style = database.get_style_by_uuid(style_id, is_collection=False)
        if not self.existing_style:
            raise e.ThingNotFound("Style", style_id)
        if self.existing_style.user_id != self.user.id and not self.user.moderator:
            raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}")
        if self.existing_style.user_id != self.user.id and self.user.moderator:
            # Log the moderator's alias (previously logged the boolean `moderator` flag).
            logger.info(f"Moderator {self.user.get_unique_alias()} added example to style {self.existing_style.id}")
        if self.existing_style.style_type == "text":
            raise e.BadRequest("Cannot add image examples to text styles.")
        if len(self.existing_style.examples) >= 4:
            raise e.Forbidden("You cannot have more than 4 examples for each style", "TooManyStyleExamples")
        previous_primary = None
        is_primary = self.args.primary
        for example in self.existing_style.examples:
            if example.url == self.args.url:
                raise e.BadRequest(f"The url '{self.args.url}' is already used for this style.", "ExampleURLAlreadyInUse")
            if example.primary is True:
                previous_primary = example
        # Demote the old primary only when one exists. Previously this raised
        # AttributeError when the first example was submitted with primary=True
        # (the `patch` method below already carried the correct guard).
        if self.args.primary is True and previous_primary is not None:
            previous_primary.primary = False
        # If we have no primary yet, the first image becomes the default primary.
        elif not is_primary and previous_primary is None:
            is_primary = True
        new_example = StyleExample(
            style_id=self.existing_style.id,
            url=self.args.url,
            primary=is_primary,
        )
        db.session.add(new_example)
        db.session.commit()
        return {
            "id": new_example.id,
            "message": "OK",
            "warnings": {},
        }, 200
with 'https://'") + self.example = db.session.query(StyleExample).filter_by(id=example_id).first() + if not self.example: + raise e.ThingNotFound("Style Example", example_id) + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Style Example PATCH") + self.existing_style = database.get_style_by_uuid(style_id, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound("Style", style_id) + if self.existing_style.user_id != self.user.id and not self.user.moderator: + raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}") + if self.existing_style.user_id != self.user.id and self.user.moderator: + logger.info(f"Moderator {self.user.moderator} patched style example {self.existing_style.id}") + previous_primary = None + for example in self.existing_style.examples: + if example.id == self.example.id: + continue + if example.url == self.args.url: + raise e.BadRequest(f"The url '{self.args.url}'is already used for this style.", "ExampleURLAlreadyInUse") + if example.primary is True: + previous_primary = example + if not self.args.primary and previous_primary is None: + raise e.BadRequest("You cannot remove the last primary example of this style") + if self.args.primary is True and previous_primary is not None: + previous_primary.primary = False + self.example.primary = self.args.primary + if self.args.url: + self.example.url = self.args.url + db.session.commit() + return { + "id": self.example.id, + "message": "OK", + "warnings": {}, + }, 200 + + @api.expect(parsers.apikey_parser) + @api.marshal_with( + models.response_model_simple_response, + code=200, + description="Style Example Deleted", + skip_none=True, + ) + @api.response(400, "Validation Error", models.response_model_validation_errors) + @api.response(401, "Invalid API Key", models.response_model_error) + def delete(self, style_id, example_id): + """Deletes an image style example.""" + self.args = 
parsers.apikey_parser.parse_args() + self.example = db.session.query(StyleExample).filter_by(id=example_id).first() + if not self.example: + raise e.ThingNotFound("Style Example", example_id) + self.user = database.find_user_by_api_key(self.args["apikey"]) + if not self.user: + raise e.InvalidAPIKey("Style Example DELETE") + self.existing_style = database.get_style_by_uuid(style_id, is_collection=False) + if not self.existing_style: + raise e.ThingNotFound("Style", style_id) + if self.existing_style.user_id != self.user.id and not self.user.moderator: + raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}") + if self.existing_style.user_id != self.user.id and self.user.moderator: + logger.info(f"Moderator {self.user.moderator} deleted style example {self.existing_style.id}") + if self.example.primary: + for example in self.existing_style.examples: + if example.id != self.example.id: + example.primary = True + db.session.delete(self.example) + db.session.commit() + return ({"message": "OK"}, 200) diff --git a/horde/bridge_reference.py b/horde/bridge_reference.py index 453bd5dc..f25de29f 100644 --- a/horde/bridge_reference.py +++ b/horde/bridge_reference.py @@ -9,6 +9,7 @@ BRIDGE_CAPABILITIES = { "AI Horde Worker reGen": { + 9: {"flux"}, 8: {"layer_diffuse"}, 7: {"qr_code", "extra_texts", "workflow"}, 6: {"stable_cascade_2pass"}, @@ -163,6 +164,14 @@ }, } +LLM_VALIDATED_BACKENDS = { + "AI Horde Worker", + "AI Horde Worker~aphrodite~oai", + "AI Horde Worker~aphrodite~kai", + "KoboldCppEmbedWorker", + "TabbyAPI", +} + @logger.catch(reraise=True) def parse_bridge_agent(bridge_agent): @@ -183,6 +192,7 @@ def parse_bridge_agent(bridge_agent): @logger.catch(reraise=True) def check_bridge_capability(capability, bridge_agent): bridge_name, bridge_version = parse_bridge_agent(bridge_agent) + # logger.debug([bridge_name, bridge_version]) if bridge_name not in BRIDGE_CAPABILITIES: return False total_capabilities = set() @@ -192,9 +202,16 @@ def 
@logger.catch(reraise=True)
def is_backed_validated(bridge_agent):
    """Return True when the bridge agent's backend name is a validated LLM backend.

    Membership is checked against LLM_VALIDATED_BACKENDS by the parsed backend
    name only; the version component is ignored.

    NOTE(review): the name looks like a typo for `is_backend_validated`. Kept
    as-is because callers elsewhere may already reference it — confirm before
    renaming.
    """
    bridge_name, _ = parse_bridge_agent(bridge_agent)
    return bridge_name in LLM_VALIDATED_BACKENDS
    def set_job_ttl(self):
        """Set how many seconds this job may stay pending before it is considered stale and cancelled.

        The base implementation uses a flat 150 seconds; it should be overridden
        by the individual hordes depending on how they calculate the TTL.
        """
        self.job_ttl = 150
        db.session.commit()
db.Model.metadata, + db.Column("style_id", db.ForeignKey("styles.id", ondelete="CASCADE"), primary_key=True), + db.Column("collection_id", db.ForeignKey("style_collections.id", ondelete="CASCADE"), primary_key=True), +) + + +class StyleCollection(db.Model): + __tablename__ = "style_collections" + __table_args__ = ( + UniqueConstraint( + "user_id", + "name", + name="user_id_name", + ), + ) + id = db.Column(uuid_column_type(), primary_key=True, default=get_db_uuid) + style_type = db.Column(db.String(30), nullable=False, index=True) + info = db.Column(db.String(1000), default="") + name = db.Column(db.String(100), default="", unique=False, nullable=False, index=True) + use_count = db.Column(db.Integer, default=0, nullable=False, server_default=expression.literal(0), index=True) + public = db.Column(db.Boolean, default=False, nullable=False) + + created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + updated = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + + user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + user = db.relationship("User", back_populates="style_collections") + styles: Mapped[list[Style]] = db.relationship(secondary="style_collection_mapping", back_populates="collections") + + def create(self, styles): + for st in styles: + self.styles.append(st) + db.session.add(self) + db.session.commit() + + # Should be extended by each specific horde + @logger.catch(reraise=True) + def get_details(self, details_privilege=0): + """We display these in the collections list json""" + ret_dict = { + "name": self.name, + "id": self.id, + "creator": self.user.get_unique_alias(), + "use_count": self.use_count, + "public": self.public, + "type": self.style_type, + } + styles_array = [] + for s in self.styles: + styles_array.append( + { + "name": s.get_unique_name(), + "id": str(s.id), + }, + ) + ret_dict["styles"] = styles_array + return ret_dict + + def get_model_names(self): + return 
class Style(db.Model):
    """A stored generation style: a prompt template plus params, models and tags, owned by a user."""

    __tablename__ = "styles"
    __table_args__ = (
        UniqueConstraint(
            "user_id",
            "name",
            name="style_user_id_name",
        ),
    )
    id = db.Column(uuid_column_type(), primary_key=True, default=get_db_uuid)
    style_type = db.Column(db.String(30), nullable=False, index=True)
    info = db.Column(db.String(1000), nullable=True)
    showcase = db.Column(db.String(1000), nullable=True)
    name = db.Column(db.String(100), unique=False, nullable=False, index=True)
    public = db.Column(db.Boolean, default=False, nullable=False)
    nsfw = db.Column(db.Boolean, default=False, nullable=False)
    prompt = db.Column(db.Text, nullable=False)
    params = db.Column(MutableDict.as_mutable(json_column_type), default={}, nullable=False)

    use_count = db.Column(db.Integer, default=0, nullable=False, server_default=expression.literal(0), index=True)
    votes = db.Column(db.Integer, default=0, nullable=False, server_default=expression.literal(0), index=True)

    created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    updated = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, onupdate=datetime.utcnow)

    user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    user = db.relationship("User", back_populates="styles")
    collections: Mapped[list[StyleCollection]] = db.relationship(secondary="style_collection_mapping", back_populates="styles")
    models = db.relationship("StyleModel", back_populates="style", cascade="all, delete-orphan")
    tags = db.relationship("StyleTag", back_populates="style", cascade="all, delete-orphan")
    examples = db.relationship("StyleExample", back_populates="style", cascade="all, delete-orphan")

    def create(self):
        """Persist a newly constructed style."""
        db.session.add(self)
        db.session.commit()

    def set_name(self, new_name):
        """Rename the style after cleanliness checks. Returns "OK"."""
        if self.name == new_name:
            return "OK"
        self.name = ensure_clean(new_name, "style name")
        db.session.commit()
        return "OK"

    def set_info(self, new_info):
        """Update the style's info text after cleanliness checks. Returns "OK"."""
        if self.info == new_info:
            return "OK"
        self.info = ensure_clean(new_info, "style info")
        db.session.commit()
        return "OK"

    def delete(self):
        db.session.delete(self)
        db.session.commit()

    def record_usage(self):
        """Increment the style's usage counter.

        Fixed: previously incremented `self.uses`, which is not a column on this
        model (the counter column is `use_count`) and raised AttributeError.
        """
        self.use_count += 1
        db.session.commit()

    def record_contribution(self, contributions, kudos):
        # NOTE(review): none of `contributions`, `fulfilments`, `kudos` or
        # `last_active` are columns defined on Style (they appear copied from
        # the Worker model), so calling this will raise AttributeError.
        # Left untouched pending confirmation of the intended columns.
        self.contributions = round(self.contributions + contributions, 2)
        self.fulfilments += 1
        self.kudos = round(self.kudos + kudos, 2)
        self.last_active = datetime.utcnow()
        db.session.commit()

    # Should be extended by each specific horde
    @logger.catch(reraise=True)
    def get_details(self, details_privilege=0):
        """We display these in the styles list json"""
        ret_dict = {
            "name": self.name,
            "info": self.info,
            "id": self.id,
            "params": self.params,
            "prompt": self.prompt,
            "tags": self.get_tag_names(),
            "models": self.get_model_names(),
            "examples": self.examples,
            "creator": self.user.get_unique_alias(),
            "use_count": self.use_count,
            "public": self.public,
            "nsfw": self.nsfw,
        }
        return ret_dict

    def get_model_names(self):
        """Return the list of model names attached to this style."""
        return [m.model for m in self.models]

    def get_tag_names(self):
        """Return the list of tag strings attached to this style."""
        return [t.tag for t in self.tags]

    def parse_tags(self, tags):
        """Parses the tags provided for the style into a set (cleaned, truncated to 100 chars, capped at 10)."""
        tags = [ensure_clean(tag[0:100], "style tag") for tag in tags]
        del tags[10:]
        return set(tags)

    def parse_models(self, models):
        """Parses the models provided for the style into a set (cleaned, truncated to 100 chars, capped at 5)."""
        models = [ensure_clean(model_name[0:100], "style model") for model_name in models]
        del models[5:]
        return set(models)

    def set_models(self, models):
        """Replace this style's model rows with the given names, skipping the write when unchanged."""
        models = self.parse_models(models)
        existing_model_names = set(self.get_model_names())
        if existing_model_names == models:
            return
        db.session.query(StyleModel).filter_by(style_id=self.id).delete()
        db.session.flush()
        for model_name in models:
            model = StyleModel(style_id=self.id, model=model_name)
            db.session.add(model)
        db.session.commit()

    def set_tags(self, tags):
        """Replace this style's tag rows with the given tags, skipping the write when unchanged."""
        tags = self.parse_tags(tags)
        existing_tags = set(self.get_tag_names())
        if existing_tags == tags:
            return
        db.session.query(StyleTag).filter_by(style_id=self.id).delete()
        db.session.flush()
        for tag_name in tags:
            tag = StyleTag(style_id=self.id, tag=tag_name)
            db.session.add(tag)
        db.session.commit()

    def get_unique_name(self):
        """Return the globally unique name of this style: '<user alias>::style::<name>'."""
        return f"{self.user.get_unique_alias()}::style::{self.name}"
    def record_style(self, kudos, contrib_type):
        """Record a style contribution in the user's records and adjust their kudos.

        Increments the UserRecordTypes.STYLE record of the given `contrib_type`
        by 1 and applies `kudos` via modify_kudos under the "styled" bucket.
        """
        self.update_user_record(
            record_type=UserRecordTypes.STYLE,
            record=contrib_type,
            increment_value=1,
        )
        self.modify_kudos(kudos, "styled")
kudos) All the evaluating Kudos added to their total and they automatically become trusted @@ -863,6 +873,17 @@ def get_details(self, details_privilege=0): # unnecessary information, since the workers themselves wil be visible # "public_workers": self.public_workers, } + styles_array = [] + for s in self.styles: + if s.public or details_privilege >= 1: + styles_array.append( + { + "name": s.get_unique_name(), + "id": str(s.id), + "type": str(s.style_type), + }, + ) + ret_dict["styles"] = styles_array if self.public_workers or details_privilege >= 1: workers_array = [] for worker in self.workers: @@ -880,6 +901,9 @@ def get_details(self, details_privilege=0): for wp in self.waiting_prompts: if wp.wp_type not in ret_dict["active_generations"]: ret_dict["active_generations"][wp.wp_type] = [] + # We don't return anon list of gens + if self.is_anon(): + break ret_dict["active_generations"][wp.wp_type].append(str(wp.id)) if details_privilege >= 2: mk_dict = { diff --git a/horde/classes/base/waiting_prompt.py b/horde/classes/base/waiting_prompt.py index 9075c2c9..2b99c6ae 100644 --- a/horde/classes/base/waiting_prompt.py +++ b/horde/classes/base/waiting_prompt.py @@ -11,15 +11,15 @@ from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.sql import expression -from horde import horde_redis as hr from horde import vars as hv from horde.bridge_reference import check_bridge_capability from horde.classes.base.processing_generation import ProcessingGeneration from horde.classes.kobold.processing_generation import TextProcessingGeneration from horde.classes.stable.processing_generation import ImageProcessingGeneration from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger -from horde.utils import get_db_uuid, get_expiry_date +from horde.utils import get_db_uuid, get_expiry_date, get_extra_slow_expiry_date procgen_classes = { "template": ProcessingGeneration, @@ -91,7 +91,9 @@ class WaitingPrompt(db.Model): 
ipaddr = db.Column(db.String(39)) # ipv6 safe_ip = db.Column(db.Boolean, default=False, nullable=False) trusted_workers = db.Column(db.Boolean, default=False, nullable=False, index=True) + validated_backends = db.Column(db.Boolean, default=True, nullable=False, index=True) slow_workers = db.Column(db.Boolean, default=True, nullable=False, index=True) + extra_slow_workers = db.Column(db.Boolean, default=False, nullable=False, index=True) worker_blacklist = db.Column(db.Boolean, default=False, nullable=False, index=True) faulted = db.Column(db.Boolean, default=False, nullable=False, index=True) active = db.Column(db.Boolean, default=False, nullable=False, index=True) @@ -104,6 +106,7 @@ class WaitingPrompt(db.Model): things = db.Column(db.BigInteger, default=0, nullable=False) total_usage = db.Column(db.Float, default=0, nullable=False) extra_priority = db.Column(db.Integer, default=0, nullable=False, index=True) + # TODO: Delete. Obsoleted. job_ttl = db.Column(db.Integer, default=150, nullable=False) disable_batching = db.Column(db.Boolean, default=False, nullable=False) webhook = db.Column(db.String(1024)) @@ -162,7 +165,7 @@ def set_models(self, model_names=None): model_entry = WPModels(model=model, wp_id=self.id) db.session.add(model_entry) - def activate(self, downgrade_wp_priority=False, extra_source_images=None): + def activate(self, downgrade_wp_priority=False, extra_source_images=None, kudos_adjustment=0): """We separate the activation from __init__ as often we want to check if there's a valid worker for it Before we add it to the queue """ @@ -184,6 +187,7 @@ def activate(self, downgrade_wp_priority=False, extra_source_images=None): self.extra_source_images = {"esi": extra_source_images} # Extra source images add more infrastructure costs, which are represented with a kudos tax horde_tax += 5 * len(extra_source_images) + horde_tax += kudos_adjustment self.record_usage(raw_things=0, kudos=horde_tax, usage_type=self.wp_type, avoid_burn=True) # 
logger.debug(f"wp {self.id} initiated and paying horde tax: {horde_tax}") db.session.commit() @@ -203,7 +207,6 @@ def extract_params(self): self.things = 0 self.total_usage = round(self.things * self.n, 2) self.prepare_job_payload() - self.set_job_ttl() db.session.commit() def prepare_job_payload(self): @@ -240,7 +243,7 @@ def start_generation(self, worker, amount=1): self.n -= safe_amount payload = self.get_job_payload(current_n) # This does a commit as well - self.refresh() + self.refresh(worker) procgen_class = procgen_classes[self.wp_type] gens_list = [] model = None @@ -282,31 +285,47 @@ def get_pop_payload(self, procgen_list, payload): "id": procgen_list[0].id, "model": procgen_list[0].model, "ids": [g.id for g in procgen_list], + "ttl": procgen_list[0].job_ttl, } if self.extra_source_images and check_bridge_capability("extra_source_images", procgen_list[0].worker.bridge_agent): prompt_payload["extra_source_images"] = self.extra_source_images["esi"] return prompt_payload - def is_completed(self): - if self.faulted: - return True - if self.needs_gen(): - return False + def count_finished_jobs(self): procgen_class = procgen_classes[self.wp_type] - finished_procgens = ( + return ( db.session.query(procgen_class.wp_id) .filter( procgen_class.wp_id == self.id, - procgen_class.fake == False, # noqa E712 + procgen_class.fake.is_(False), or_( - procgen_class.faulted == True, # noqa E712 + procgen_class.faulted.is_(True), procgen_class.generation != None, # noqa E712 ), ) .count() ) - if finished_procgens < self.jobs: + + def count_processing_jobs(self): + procgen_class = procgen_classes[self.wp_type] + return ( + db.session.query(procgen_class.wp_id) + .filter( + procgen_class.wp_id == self.id, + procgen_class.fake.is_(False), + procgen_class.faulted.is_(False), + procgen_class.generation.is_(None), + ) + .count() + ) + + def is_completed(self): + if self.faulted: + return True + if self.needs_gen(): + return False + if self.count_finished_jobs() - 
self.count_processing_jobs() < self.jobs: return False return True @@ -456,8 +475,13 @@ def abort_for_maintenance(self): except Exception as err: logger.warning(f"Error when aborting WP. Skipping: {err}") - def refresh(self): - self.expiry = get_expiry_date() + def refresh(self, worker=None): + if worker is not None and worker.extra_slow_worker is True: + self.expiry = get_extra_slow_expiry_date() + else: + new_expiry = get_expiry_date() + if self.expiry < new_expiry: + self.expiry = new_expiry db.session.commit() def is_stale(self): @@ -468,13 +492,6 @@ def is_stale(self): def get_priority(self): return self.extra_priority - def set_job_ttl(self): - """Returns how many seconds each job request should stay waiting before considering it stale and cancelling it - This function should be overriden by the invididual hordes depending on how the calculating ttl - """ - self.job_ttl = 150 - db.session.commit() - def refresh_worker_cache(self): worker_ids = [worker.worker_id for worker in self.workers] worker_string_ids = [str(worker.worker_id) for worker in self.workers] diff --git a/horde/classes/base/worker.py b/horde/classes/base/worker.py index aecd3669..f5da87f9 100644 --- a/horde/classes/base/worker.py +++ b/horde/classes/base/worker.py @@ -9,11 +9,11 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.hybrid import hybrid_property -from horde import horde_redis as hr from horde import vars as hv from horde.classes.base import settings from horde.discord import send_pause_notification from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.suspicions import SUSPICION_LOGS, Suspicions from horde.utils import get_db_uuid, is_profane, sanitize_string @@ -121,6 +121,7 @@ class WorkerTemplate(db.Model): # Used by all workers to record how much they can pick up to generate # The value of this column is dfferent per worker type max_power = db.Column(db.Integer, default=20, 
nullable=False) + extra_slow_worker = db.Column(db.Boolean, default=False, nullable=False, index=True) paused = db.Column(db.Boolean, default=False, nullable=False) maintenance = db.Column(db.Boolean, default=False, nullable=False) @@ -154,7 +155,7 @@ def speed(self) -> int: def speed(cls): performance_avg = db.select(func.avg(WorkerPerformance.performance)).where(WorkerPerformance.worker_id == cls.id).label("speed") return db.case( - [(performance_avg == None, 1 * hv.thing_divisors[cls.wtype])], # noqa E712 + (performance_avg == None, 1 * hv.thing_divisors[cls.wtype]), # noqa E712 else_=performance_avg, ) @@ -196,7 +197,7 @@ def report_suspicion(self, amount=1, reason=Suspicions.WORKER_PROFANITY, formats f"Last suspicion log: {reason.name}.\n" f"Total Suspicion {self.get_suspicion()}", ) - db.session.commit() + db.session.flush() def get_suspicion_reasons(self): return set([s.suspicion_id for s in self.suspicions]) @@ -261,10 +262,6 @@ def toggle_paused(self, is_paused_active): # This should be extended by each worker type def check_in(self, **kwargs): - # To avoid excessive commits, - # we only record new changes on the worker every 30 seconds - if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30: - return self.ipaddr = kwargs.get("ipaddr", None) self.bridge_agent = sanitize_string(kwargs.get("bridge_agent", "unknown:0:unknown")) self.threads = kwargs.get("threads", 1) @@ -275,6 +272,10 @@ def check_in(self, **kwargs): self.prioritized_users = kwargs.get("prioritized_users", []) if not kwargs.get("safe_ip", True) and not self.user.trusted: self.report_suspicion(reason=Suspicions.UNSAFE_IP) + # To avoid excessive commits, + # we only record new uptime on the worker every 30 seconds + if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30: + return if not self.is_stale() and not self.paused and not self.maintenance: 
self.uptime += (datetime.utcnow() - self.last_check_in).total_seconds() # Every 10 minutes of uptime gets 100 kudos rewarded @@ -293,7 +294,6 @@ def check_in(self, **kwargs): # So that they have to stay up at least 10 mins to get uptime kudos self.last_reward_uptime = self.uptime self.last_check_in = datetime.utcnow() - db.session.commit() def get_human_readable_uptime(self): if self.uptime < 60: @@ -472,6 +472,7 @@ def get_details(self, details_privilege=0): ret_dict["suspicious"] = len(self.suspicions) if details_privilege >= 1 or self.user.public_workers: ret_dict["owner"] = self.user.get_unique_alias() + if details_privilege >= 1: ret_dict["ipaddr"] = self.ipaddr ret_dict["contact"] = self.user.contact return ret_dict @@ -511,7 +512,8 @@ def check_in(self, **kwargs): self.set_models(kwargs.get("models")) self.nsfw = kwargs.get("nsfw", True) self.set_blacklist(kwargs.get("blacklist", [])) - db.session.commit() + self.extra_slow_worker = kwargs.get("extra_slow_worker", False) + # Commit should happen on calling extensions def set_blacklist(self, blacklist): # We don't allow more workers to claim they can server more than 50 models atm (to prevent abuse) @@ -527,7 +529,7 @@ def set_blacklist(self, blacklist): for word in blacklist: blacklisted_word = WorkerBlackList(worker_id=self.id, word=word[0:15]) db.session.add(blacklisted_word) - db.session.commit() + db.session.flush() def refresh_model_cache(self): models_list = [m.model for m in self.models] @@ -563,7 +565,7 @@ def set_models(self, models): return # logger.debug([existing_model_names,models, existing_model_names == models]) db.session.query(WorkerModel).filter_by(worker_id=self.id).delete() - db.session.commit() + db.session.flush() for model_name in models: model = WorkerModel(worker_id=self.id, model=model_name) db.session.add(model) diff --git a/horde/classes/kobold/processing_generation.py b/horde/classes/kobold/processing_generation.py index 84bb1ede..a28f8be9 100644 --- 
a/horde/classes/kobold/processing_generation.py +++ b/horde/classes/kobold/processing_generation.py @@ -6,6 +6,9 @@ import os from horde import vars as hv +from horde.bridge_reference import ( + is_backed_validated, +) from horde.classes.base.processing_generation import ProcessingGeneration from horde.classes.kobold.genstats import record_text_statistic from horde.flask import db @@ -56,7 +59,10 @@ def get_gen_kudos(self): # This is the approximate reward for generating with a 2.7 model at 4bit model_multiplier = model_reference.get_text_model_multiplier(self.model) parameter_bonus = (max(model_multiplier, 13) / 13) ** 0.20 - kudos = self.get_things_count() * parameter_bonus * model_multiplier / 100 + kudos = self.get_things_count() * parameter_bonus * model_multiplier / 125 + # Unvalidated backends have their rewards cut to 30% + if not is_backed_validated(self.worker.bridge_agent): + kudos *= 0.3 return round(kudos * context_multiplier, 2) def log_aborted_generation(self): diff --git a/horde/classes/kobold/waiting_prompt.py b/horde/classes/kobold/waiting_prompt.py index 761c9235..c5a842a6 100644 --- a/horde/classes/kobold/waiting_prompt.py +++ b/horde/classes/kobold/waiting_prompt.py @@ -62,10 +62,10 @@ def prepare_job_payload(self, initial_dict=None): self.gen_payload["n"] = 1 db.session.commit() - def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None): + def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None, kudos_adjustment=0): # We separate the activation from __init__ as often we want to check if there's a valid worker for it # Before we add it to the queue - super().activate(downgrade_wp_priority, extra_source_images=extra_source_images) + super().activate(downgrade_wp_priority, extra_source_images=extra_source_images, kudos_adjustment=kudos_adjustment) proxied_account = "" if self.proxied_account: proxied_account = f":{self.proxied_account}" 
diff --git a/horde/classes/kobold/worker.py b/horde/classes/kobold/worker.py index 81eccbc3..e30cdd9f 100644 --- a/horde/classes/kobold/worker.py +++ b/horde/classes/kobold/worker.py @@ -8,9 +8,12 @@ from sqlalchemy.dialects.postgresql import UUID from horde import exceptions as e -from horde import horde_redis as hr +from horde.bridge_reference import ( + is_backed_validated, +) from horde.classes.base.worker import Worker from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.model_reference import model_reference from horde.utils import sanitize_string @@ -46,7 +49,7 @@ def check_in(self, max_length, max_context_length, softprompts, **kwargs): super().check_in(**kwargs) self.max_length = max_length self.max_context_length = max_context_length - self.set_softprompts(softprompts) + self.set_softprompts(softprompts) # Does a commit as well paused_string = "" if self.paused: paused_string = "(Paused) " @@ -54,6 +57,7 @@ def check_in(self, max_length, max_context_length, softprompts, **kwargs): f"{paused_string}Text Worker {self.name} checked-in, offering models {self.models} " f"at {self.max_length} max tokens and {self.max_context_length} max content length.", ) + db.session.commit() def refresh_softprompt_cache(self): softprompts_list = [s.softprompt for s in self.softprompts] @@ -97,11 +101,10 @@ def set_softprompts(self, softprompts): ], ) db.session.query(TextWorkerSoftprompts).filter_by(worker_id=self.id).delete() - db.session.commit() + db.session.flush() for softprompt_name in softprompts: softprompt = TextWorkerSoftprompts(worker_id=self.id, softprompt=softprompt_name) db.session.add(softprompt) - db.session.commit() self.refresh_softprompt_cache() def calculate_uptime_reward(self): @@ -114,6 +117,9 @@ def calculate_uptime_reward(self): param_multiplier = model_reference.get_text_model_multiplier(model) / 7 if param_multiplier < 0.25: param_multiplier = 0.25 + # Unvalidated 
backends get less kudos + if not is_backed_validated(self.bridge_agent): + base_kudos *= 0.3 # The uptime is based on both how much context they provide, as well as how many parameters they're serving return round(base_kudos * param_multiplier, 2) @@ -125,6 +131,8 @@ def can_generate(self, waiting_prompt): return [False, "max_context_length"] if self.max_length < waiting_prompt.max_length: return [False, "max_length"] + if waiting_prompt.validated_backends and not is_backed_validated(self.bridge_agent): + return [False, "bridge_version"] matching_softprompt = True if waiting_prompt.softprompt: matching_softprompt = False diff --git a/horde/classes/stable/interrogation.py b/horde/classes/stable/interrogation.py index c0b4c083..62a84e9b 100644 --- a/horde/classes/stable/interrogation.py +++ b/horde/classes/stable/interrogation.py @@ -9,10 +9,10 @@ from sqlalchemy import JSON, Enum from sqlalchemy.dialects.postgresql import JSONB, UUID -from horde import horde_redis as hr from horde.consts import KNOWN_POST_PROCESSORS from horde.enums import State from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.r2 import generate_procgen_download_url, generate_procgen_upload_url from horde.utils import get_db_uuid, get_expiry_date, get_interrogation_form_expiry_date diff --git a/horde/classes/stable/processing_generation.py b/horde/classes/stable/processing_generation.py index efc70c5f..9b94d6f8 100644 --- a/horde/classes/stable/processing_generation.py +++ b/horde/classes/stable/processing_generation.py @@ -16,7 +16,6 @@ download_procgen_image, generate_procgen_download_url, upload_generated_image, - upload_prompt, upload_shared_generated_image, upload_shared_metadata, ) @@ -64,6 +63,9 @@ def get_gen_kudos(self): if self.wp.params.get("hires_fix", False): return self.wp.kudos * 7 return self.wp.kudos * 4 + if model_reference.get_model_baseline(self.model) in ["flux_1"]: + # Flux is double the size of 
SDXL and much slower, so it gives double the rewards from it. + return self.wp.kudos * 8 return self.wp.kudos def log_aborted_generation(self): @@ -75,21 +77,31 @@ def set_generation(self, generation, things_per_sec, **kwargs): - if kwargs.get("censored", False): - self.censored = True state = kwargs.get("state", "ok") - if state in ["censored", "csam"]: + censored = False + gen_metadata = kwargs.get("gen_metadata") if kwargs.get("gen_metadata") is not None else [] + for metadata in gen_metadata: + if metadata.get("type") != "censorship": + # this metadata isn't about censorship + continue + if metadata.get("value") == "csam": + censored = "csam" + else: + censored = "nsfw" + if censored is not False: self.censored = True db.session.commit() - if state == "csam": - prompt_dict = { - "prompt": self.wp.prompt, - "user": self.wp.user.get_unique_alias(), - "type": "clip", - } - upload_prompt(prompt_dict) + # Disabled prompt gathering for now + # if censored == "csam": + # prompt_dict = { + # "prompt": self.wp.prompt, + # "user": self.wp.user.get_unique_alias(), + # "type": "clip", + # } + # upload_prompt(prompt_dict) elif state == "faulted": - self.wp.n += 1 + if self.wp.count_finished_jobs() < self.wp.jobs: + self.wp.n += 1 self.abort() if self.is_completed(): return 0 @@ -116,7 +128,7 @@ def set_generation(self, generation, things_per_sec, **kwargs): record_image_statistic(self) if self.wp.shared and not self.fake and generation == "R2": self.upload_generation_metadata() - if state == "csam": + if censored == "csam": self.wp.user.record_problem_job( procgen=self, ipaddr=self.wp.ipaddr, @@ -140,3 +152,24 @@ def upload_generation_metadata(self): f.write(json_object) upload_shared_metadata(filename) os.remove(filename) + + def set_job_ttl(self): + # We are aiming here for a graceful min 2sec/it speed on workers for 512x512 which is well below our requested min 0.5mps/s, + # to buffer for model loading and allow for the occasional 
slowdown without dropping jobs. + # There is also a minimum of 2mins, regardless of steps and resolution used and an extra 30 seconds for model loading. + # This means a worker at 1mps/s should be able to finish a 512x512x50 request comfortably within 30s but we allow up to 2.5mins. + # This number then increases linearly based on the resolution requested. + # Using this formula, a 1536x768x40 request is expected to take ~50s on a 1mps/s worker, but we will only time out after 390s. + ttl_multiplier = (self.wp.width * self.wp.height) / (512 * 512) + self.job_ttl = 30 + (self.wp.get_accurate_steps() * 2 * ttl_multiplier) + # CN is 3 times slower + if self.wp.gen_payload.get("control_type"): + self.job_ttl = self.job_ttl * 2 + # Flux is way slower than Stable Diffusion + if any(model_reference.get_model_baseline(mn) in ["flux_1"] for mn in self.wp.get_model_names()): + self.job_ttl = self.job_ttl * 3 + if self.job_ttl < 150: + self.job_ttl = 150 + if self.worker.extra_slow_worker is True: + self.job_ttl = self.job_ttl * 3 + db.session.commit() diff --git a/horde/classes/stable/waiting_prompt.py b/horde/classes/stable/waiting_prompt.py index e82b6e9c..ac582bc2 100644 --- a/horde/classes/stable/waiting_prompt.py +++ b/horde/classes/stable/waiting_prompt.py @@ -13,6 +13,7 @@ from horde.classes.base.waiting_prompt import WaitingPrompt from horde.classes.stable.kudos import KudosModel from horde.consts import ( + BASELINE_BATCHING_MULTIPLIERS, HEAVY_POST_PROCESSORS, KNOWN_LCM_LORA_IDS, KNOWN_LCM_LORA_VERSIONS, @@ -124,8 +125,9 @@ def extract_params(self): self.trusted_workers = True self.shared = False self.prepare_job_payload(self.params) - self.set_job_ttl() # Commit will happen in prepare_job_payload() + # logger.debug(self.params) + # logger.debug(self.prompt) @logger.catch(reraise=True) def prepare_job_payload(self, initial_dict=None): @@ -197,6 +199,7 @@ def get_pop_payload(self, procgen_list, payload): "id": procgen.id, "model": procgen.model, "ids": [g.id for g in 
procgen_list], + "ttl": procgen_list[0].job_ttl, } if self.source_image and check_bridge_capability("img2img", procgen.worker.bridge_agent): if check_bridge_capability("r2_source", procgen.worker.bridge_agent): @@ -226,10 +229,10 @@ def get_pop_payload(self, procgen_list, payload): # logger.debug([payload,prompt_payload]) return prompt_payload - def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None): + def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None, kudos_adjustment=0): # We separate the activation from __init__ as often we want to check if there's a valid worker for it # Before we add it to the queue - super().activate(downgrade_wp_priority, extra_source_images=extra_source_images) + super().activate(downgrade_wp_priority, extra_source_images=extra_source_images, kudos_adjustment=kudos_adjustment) if source_image or source_mask: self.source_image = source_image self.source_mask = source_mask @@ -364,7 +367,7 @@ def require_upfront_kudos(self, counted_totals, total_threads): max_res = 768 # We allow everyone to use SDXL up to 1024 if max_res < 1024 and any( - model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names + model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade", "flux_1"] for mn in model_names ): max_res = 1024 if max_res > 1024: @@ -372,10 +375,8 @@ def require_upfront_kudos(self, counted_totals, total_threads): # Using more than 10 steps with LCM requires upfront kudos if self.is_using_lcm() and self.get_accurate_steps() > 10: return (True, max_res, False) - # Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse. 
- if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30: - return (True, max_res, False) - if self.get_accurate_steps() > 50: + # Some models don't require a lot of steps, so we check their requirements. The max steps we allow without upfront kudos is 40 + if any(model_reference.get_model_requirements(mn).get("max_steps", 40) < self.get_accurate_steps() for mn in model_names): return (True, max_res, False) if self.width * self.height > max_res * max_res: return (True, max_res, False) @@ -400,9 +401,7 @@ def downgrade(self, max_resolution): # Break, just in case we went too low if self.width * self.height < 512 * 512: break - max_steps = 50 - if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in self.get_model_names()): - max_steps = 30 + max_steps = min(model_reference.get_model_requirements(mn).get("max_steps", 30) for mn in self.get_model_names()) if self.params.get("control_type"): max_steps = 20 if self.is_using_lcm(): @@ -435,7 +434,7 @@ def get_accurate_steps(self): if self.params.get("sampler_name", "k_euler_a") in ["k_dpm_adaptive"]: # This sampler chooses the steps amount automatically # and disregards the steps value from the user - # so we just calculate it as an average 50 steps + # so we just calculate it as an average 40 steps return 40 steps = self.params["steps"] if self.params.get("sampler_name", "k_euler_a") in SECOND_ORDER_SAMPLERS: @@ -444,32 +443,6 @@ def get_accurate_steps(self): steps *= 2 return steps - def set_job_ttl(self): - # default is 2 minutes. Then we scale up based on resolution. 
- # This will be more accurate with a newer formula - self.job_ttl = 120 - if self.width * self.height > 2048 * 2048: - self.job_ttl = 800 - elif self.width * self.height > 1024 * 1024: - self.job_ttl = 400 - elif self.width * self.height > 728 * 728: - self.job_ttl = 260 - elif self.width * self.height >= 512 * 512: - self.job_ttl = 150 - # When too many steps are involved, we increase the expiry time - if self.get_accurate_steps() >= 200: - self.job_ttl = self.job_ttl * 3 - elif self.get_accurate_steps() >= 100: - self.job_ttl = self.job_ttl * 2 - # CN is 3 times slower - if self.gen_payload.get("control_type"): - self.job_ttl = self.job_ttl * 3 - if "SDXL_beta::stability.ai#6901" in self.get_model_names(): - logger.debug(self.get_model_names()) - self.job_ttl = 300 - # logger.info([weights_count,self.job_ttl]) - db.session.commit() - def log_faulted_prompt(self): source_processing = "txt2img" if self.source_image: @@ -499,6 +472,8 @@ def extrapolate_dry_run_kudos(self): return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1 if model_reference.get_model_baseline(model_name) in ["stable_cascade"]: return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1 + if model_reference.get_model_baseline(model_name) in ["flux_1"]: + return (self.calculate_extra_kudos_burn(kudos) * self.n * 8) + 1 # The +1 is the extra kudos burn per request return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1 @@ -513,5 +488,12 @@ def has_heavy_operations(self): return True return False + def get_highest_model_batching_multiplier(self): + highest_multiplier = 1 + for mn in self.get_model_names(): + if BASELINE_BATCHING_MULTIPLIERS.get(mn, 1) > highest_multiplier: + highest_multiplier = BASELINE_BATCHING_MULTIPLIERS.get(mn, 1) + return highest_multiplier + def count_pp(self): return len(self.params.get("post_processing", [])) diff --git a/horde/classes/stable/worker.py b/horde/classes/stable/worker.py index 72984492..514c5675 100644 --- a/horde/classes/stable/worker.py 
+++ b/horde/classes/stable/worker.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later + from horde import exceptions as e from horde.bridge_reference import ( check_bridge_capability, @@ -23,12 +24,13 @@ class ImageWorker(Worker): } # TODO: Switch to max_power max_pixels = db.Column(db.BigInteger, default=512 * 512, nullable=False) - allow_img2img = db.Column(db.Boolean, default=True, nullable=False) - allow_painting = db.Column(db.Boolean, default=True, nullable=False) - allow_post_processing = db.Column(db.Boolean, default=True, nullable=False) - allow_controlnet = db.Column(db.Boolean, default=False, nullable=False) - allow_sdxl_controlnet = db.Column(db.Boolean, default=False, nullable=False) - allow_lora = db.Column(db.Boolean, default=False, nullable=False) + allow_img2img = db.Column(db.Boolean, default=True, nullable=False, index=True) + allow_painting = db.Column(db.Boolean, default=True, nullable=False, index=True) + allow_post_processing = db.Column(db.Boolean, default=True, nullable=False, index=True) + allow_controlnet = db.Column(db.Boolean, default=False, nullable=False, index=True) + allow_sdxl_controlnet = db.Column(db.Boolean, default=False, nullable=False, index=True) + allow_lora = db.Column(db.Boolean, default=False, nullable=False, index=True) + limit_max_steps = db.Column(db.Boolean, default=False, nullable=False, index=True) wtype = "image" def check_in(self, max_pixels, **kwargs): @@ -43,6 +45,7 @@ def check_in(self, max_pixels, **kwargs): self.allow_controlnet = kwargs.get("allow_controlnet", False) self.allow_sdxl_controlnet = kwargs.get("allow_sdxl_controlnet", False) self.allow_lora = kwargs.get("allow_lora", False) + self.limit_max_steps = kwargs.get("limit_max_steps", False) if len(self.get_model_names()) == 0: self.set_models(["stable_diffusion"]) paused_string = "" @@ -138,6 +141,11 @@ def can_generate(self, waiting_prompt): and not check_bridge_capability("stable_cascade_2pass", self.bridge_agent) ): return [False, 
"bridge_version"] + if "flux_1" in model_reference.get_all_model_baselines(self.get_model_names()) and not check_bridge_capability( + "flux", + self.bridge_agent, + ): + return [False, "bridge_version"] if waiting_prompt.params.get("clip_skip", 1) > 1 and not check_bridge_capability( "clip_skip", self.bridge_agent, @@ -150,6 +158,30 @@ def can_generate(self, waiting_prompt): return [False, "bridge_version"] if not waiting_prompt.safe_ip and not self.allow_unsafe_ipaddr: return [False, "unsafe_ip"] + if self.limit_max_steps: + if len(waiting_prompt.get_model_names()) > 1: + for mn in waiting_prompt.get_model_names(): + avg_steps = ( + int( + model_reference.get_model_requirements(mn).get("min_steps", 20) + + model_reference.get_model_requirements(mn).get("max_steps", 40), + ) + / 2 + ) + if waiting_prompt.get_accurate_steps() > avg_steps: + return [False, "step_count"] + else: + # If the request has an empty model list, we compare instead to the worker's model list + for mn in self.get_model_names(): + avg_steps = ( + int( + model_reference.get_model_requirements(mn).get("min_steps", 20) + + model_reference.get_model_requirements(mn).get("max_steps", 40), + ) + / 2 + ) + if waiting_prompt.get_accurate_steps() > avg_steps: + return [False, "step_count"] # We do not give untrusted workers anon or VPN generations, to avoid anything slipping by and spooking them. 
# logger.warning(datetime.utcnow()) if not self.user.trusted: # FIXME #noqa SIM102 @@ -225,6 +257,7 @@ def get_safe_amount(self, amount, wp): if wp.has_heavy_operations(): pp_multiplier *= 1.8 mps *= pp_multiplier + mps *= wp.get_highest_model_batching_multiplier() safe_amount = round(safe_generations / mps) if safe_amount > amount: safe_amount = amount diff --git a/horde/consts.py b/horde/consts.py index d1d8f106..81cfea97 100644 --- a/horde/consts.py +++ b/horde/consts.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later -HORDE_VERSION = "4.40.3 " +HORDE_VERSION = "4.44.3" +HORDE_API_VERSION = "2.5" WHITELISTED_SERVICE_IPS = { "212.227.227.178", # Turing Bot @@ -38,6 +39,13 @@ "4x_AnimeSharp" "CodeFormers", } +# These models are very large in VRAM, so we increase the calculated MPS +# used to figure out batches by a set multiplier to reduce how many images are batched +# at a time when these models are used. +BASELINE_BATCHING_MULTIPLIERS = { + "flux_1": 3, +} + KNOWN_SAMPLERS = { "k_lms", diff --git a/horde/countermeasures.py b/horde/countermeasures.py index 1fae3f91..0f596ea0 100644 --- a/horde/countermeasures.py +++ b/horde/countermeasures.py @@ -19,26 +19,16 @@ ) ip_r = None -logger.init("IP Address Cache", status="Connecting") -if is_redis_up(): - ip_r = get_ipaddr_db() - logger.init_ok("IP Address Cache", status="Connected") -else: - logger.init_err("IP Address Cache", status="Failed") ip_s_r = None -logger.init("IP Suspicion Cache", status="Connecting") -if is_redis_up(): - ip_s_r = get_ipaddr_suspicion_db() - logger.init_ok("IP Suspicion Cache", status="Connected") -else: - logger.init_err("IP Suspicion Cache", status="Failed") ip_t_r = None -logger.init("IP Timeout Cache", status="Connecting") +logger.init("IP Caches", status="Connecting") if is_redis_up(): + ip_r = get_ipaddr_db() + ip_s_r = get_ipaddr_suspicion_db() ip_t_r = get_ipaddr_timeout_db() - logger.init_ok("IP Timeout Cache", status="Connected") + logger.init_ok("IP Caches", 
status="Connected") else: - logger.init_err("IP Timeout Cache", status="Failed") + logger.init_err("IP Caches", status="Failed") test_timeout = 0 diff --git a/horde/data/news.json b/horde/data/news.json index 9cc52794..95672da8 100644 --- a/horde/data/news.json +++ b/horde/data/news.json @@ -1,4 +1,20 @@ [ + { + "date_published": "2024-09-30", + "newspiece": "[Flux.1-Schnell](https://blackforestlabs.ai/) is now available on the AI Horde! Remember to adjust your steps and cfg when using it!", + "tags": [ + "nlnet", + "tazlin", + "db0", + "text2img", + "flux" + ], + "importance": "Information", + "more_info_urls": [ + "https://blackforestlabs.ai/" + ], + "title": "Flux.1-Schnell now available on the AI Horde." + }, { "date_published": "2024-08-07", "newspiece": "ArtBot is now an official component of Haidra! You can check the official version at [https://artbot.site](https://artbot.site which is running the latest version of it with a ton of improvements. All kudos to [Rockbandit](https://mastodon.world/@davely)", diff --git a/horde/database/functions.py b/horde/database/functions.py index 421a9ad2..3f5c0a39 100644 --- a/horde/database/functions.py +++ b/horde/database/functions.py @@ -13,13 +13,13 @@ from sqlalchemy.orm import noload import horde.classes.base.stats as stats -from horde import horde_redis as hr from horde import vars as hv from horde.bridge_reference import ( check_bridge_capability, get_supported_samplers, ) from horde.classes.base.detection import Filter +from horde.classes.base.style import Style, StyleCollection, StyleModel, StyleTag from horde.classes.base.user import KudosTransferLog, User, UserRecords, UserSharedKey from horde.classes.base.waiting_prompt import WPAllowedWorkers, WPModels from horde.classes.base.worker import WorkerModel, WorkerPerformance @@ -34,6 +34,7 @@ from horde.database.classes import FakeWPRow from horde.enums import State from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from 
horde.logger import logger from horde.model_reference import model_reference from horde.utils import hash_api_key, validate_regex @@ -218,6 +219,13 @@ def find_worker_by_name(worker_name, worker_class=ImageWorker): return worker +def find_worker_id_by_name(worker_name): + for worker_class in [ImageWorker, TextWorker, InterrogationWorker]: + worker_id = db.session.query(worker_class.id).filter_by(name=worker_name).first() + if worker_id: + return worker_id + + def worker_name_exists(worker_name): for worker_class in [ImageWorker, TextWorker, InterrogationWorker]: worker = db.session.query(worker_class).filter_by(name=worker_name).count() @@ -761,18 +769,17 @@ def count_things_for_specific_model(wp_class, procgen_class, model_name): return things, jobs +@logger.catch(reraise=True) def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, priority_user_ids=None, page=0): - # This is just the top 25 - Adjusted method to send ImageWorker object. Filters to add. + # This is just the top 3 - Adjusted method to send ImageWorker object. Filters to add. 
# TODO: Filter by ImageWorker not in WP.tricked_worker # TODO: If any word in the prompt is in the WP.blacklist rows, then exclude it (L293 in base.worker.ImageWorker.gan_generate()) PER_PAGE = 3 # how many requests we're picking up to filter further final_wp_list = ( db.session.query(ImageWaitingPrompt) .options(noload(ImageWaitingPrompt.processing_gens)) - .outerjoin( - WPModels, - WPAllowedWorkers, - ) + .outerjoin(WPModels, ImageWaitingPrompt.id == WPModels.wp_id) + .outerjoin(WPAllowedWorkers, ImageWaitingPrompt.id == WPAllowedWorkers.wp_id) .filter( ImageWaitingPrompt.n > 0, ImageWaitingPrompt.active == True, # noqa E712 @@ -840,6 +847,13 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p worker.speed >= 500000, # 0.5 MPS/s ImageWaitingPrompt.slow_workers == True, # noqa E712 ), + or_( + worker.extra_slow_worker is False, + and_( + worker.extra_slow_worker is True, + ImageWaitingPrompt.extra_slow_workers.is_(True), + ), + ), or_( not_(ImageWaitingPrompt.params.has_key("transparent")), ImageWaitingPrompt.params["transparent"].astext.cast(Boolean).is_(False), @@ -921,10 +935,8 @@ def count_skipped_image_wp(worker, models_list=None, blacklist=None, priority_us open_wp_list = ( db.session.query(ImageWaitingPrompt) .options(noload(ImageWaitingPrompt.processing_gens)) - .outerjoin( - WPModels, - WPAllowedWorkers, - ) + .outerjoin(WPModels, ImageWaitingPrompt.id == WPModels.wp_id) + .outerjoin(WPAllowedWorkers, ImageWaitingPrompt.id == WPAllowedWorkers.wp_id) .filter( ImageWaitingPrompt.n > 0, ImageWaitingPrompt.active == True, # noqa E712 @@ -1047,6 +1059,12 @@ def count_skipped_image_wp(worker, models_list=None, blacklist=None, priority_us ).count() if skipped_wps > 0: ret_dict["performance"] = skipped_wps + if worker.extra_slow_worker is True: + skipped_wps = open_wp_list.filter( + ImageWaitingPrompt.extra_slow_workers == False, # noqa E712 + ).count() + if skipped_wps > 0: + ret_dict["performance"] = ret_dict.get("performance", 
0) + skipped_wps # Count skipped WPs requiring trusted workers if worker.user.trusted is False: skipped_wps = open_wp_list.filter( @@ -1505,3 +1523,106 @@ def retrieve_regex_replacements(filter_type): def get_all_users(sort="kudos", offset=0): user_order_by = User.created.asc() if sort == "age" else User.kudos.desc() return db.session.query(User).order_by(user_order_by).offset(offset).limit(25).all() + + +def get_style_by_uuid(style_uuid: str, is_collection=None): + try: + style_uuid = uuid.UUID(style_uuid) + except ValueError: + return None + if SQLITE_MODE: + style_uuid = str(style_uuid) + style = None + if is_collection is not True: + style = db.session.query(Style).filter_by(id=style_uuid).first() + if is_collection is True or not style: + collection = db.session.query(StyleCollection).filter_by(id=style_uuid).first() + return collection + else: + return style + + +def get_style_by_name(style_name: str, is_collection=None): + """Goes through the styles and the categories and attempts to find a + style or category that matches the given name + The user can pre-specify a filter for category or style and/or username + by formatting the name like + category::db0#1::my_stylename + alternatively this format is also allowed to allow multiple users to use the same name + style::my_stylename + db0#1::my_stylename + """ + style_split = style_name.split("::") + user = None + # We don't change the is_collection if it comes preset in kwargs, as we then want it explicitly to return none + # When searching for styles in collections and vice-versa + if len(style_split) == 3: + style_name = style_split[2] + if is_collection is None: + if style_split[0] == "collection": + is_collection = True + elif style_split[0] == "style": + is_collection = False + user = find_user_by_username(style_split[1]) + if len(style_split) == 2: + style_name = style_split[1] + if style_split[0] == "collection": + if is_collection is None: + is_collection = True + elif style_split[0] == "style": + if 
is_collection is None: + is_collection = False + else: + user = find_user_by_username(style_split[0]) + seek_classes = [Style, StyleCollection] + if is_collection is True: + seek_classes = [StyleCollection] + elif is_collection is False: + seek_classes = [Style] + for class_seek in seek_classes: + style_query = db.session.query(class_seek).filter_by(name=style_name) + if user is not None: + style_query = style_query.filter_by(user_id=user.id) + style = style_query.first() + if style: + return style + + +def retrieve_available_styles( + style_type=None, + sort="popular", + public_only=True, + page=0, + tag=None, + model=None, +): + """Retrieves all style details from DB.""" + style_query = db.session.query(Style).filter_by(style_type=style_type) + if tag is not None: + style_query = style_query.join(StyleTag) + if model is not None: + style_query = style_query.join(StyleModel) + if public_only: + style_query = style_query.filter(Style.public.is_(True)) + if tag is not None: + style_query = style_query.filter(StyleTag.tag == tag) + if model is not None: + style_query = style_query.filter(StyleModel.model == model) + style_order_by = Style.created.asc() if sort == "age" else Style.use_count.desc() + return style_query.order_by(style_order_by).offset(page).limit(25).all() + + +def retrieve_available_collections( + collection_type=None, + sort="popular", + public_only=True, + page=0, +): + """Retrieves all collection details from DB.""" + style_query = db.session.query(StyleCollection) + if collection_type is not None: + style_query = style_query.filter_by(style_type=collection_type) + if public_only: + style_query = style_query.filter(StyleCollection.public.is_(True)) + style_order_by = StyleCollection.created.asc() if sort == "age" else StyleCollection.use_count.desc() + return style_query.order_by(style_order_by).offset(page).limit(25).all() diff --git a/horde/database/text_functions.py b/horde/database/text_functions.py index 5c033289..7d2f817b 100644 --- 
a/horde/database/text_functions.py +++ b/horde/database/text_functions.py @@ -10,7 +10,9 @@ from sqlalchemy.orm import noload import horde.classes.base.stats as stats -from horde import horde_redis as hr +from horde.bridge_reference import ( + is_backed_validated, +) from horde.classes.base.waiting_prompt import WPAllowedWorkers, WPModels from horde.classes.base.worker import WorkerPerformance from horde.classes.kobold.processing_generation import TextProcessingGeneration @@ -19,6 +21,7 @@ from horde.classes.kobold.waiting_prompt import TextWaitingPrompt from horde.database.functions import query_prioritized_wps from horde.flask import SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.model_reference import model_reference @@ -30,7 +33,7 @@ def convert_things_to_kudos(things, **kwargs): def get_sorted_text_wp_filtered_to_worker(worker, models_list=None, priority_user_ids=None, page=0): - # This is just the top 100 - Adjusted method to send Worker object. Filters to add. + # This is just the top 3 - Adjusted method to send Worker object. Filters to add. 
# TODO: Filter by (Worker in WP.workers) __ONLY IF__ len(WP.workers) >=1 # TODO: Filter by WP.trusted_workers == False __ONLY IF__ Worker.user.trusted == False # TODO: Filter by Worker not in WP.tricked_worker @@ -49,10 +52,8 @@ def get_sorted_text_wp_filtered_to_worker(worker, models_list=None, priority_use final_wp_list = ( db.session.query(TextWaitingPrompt) .options(noload(TextWaitingPrompt.processing_gens)) - .outerjoin( - WPModels, - WPAllowedWorkers, - ) + .outerjoin(WPModels, TextWaitingPrompt.id == WPModels.wp_id) + .outerjoin(WPAllowedWorkers, TextWaitingPrompt.id == WPAllowedWorkers.wp_id) .filter( TextWaitingPrompt.n > 0, TextWaitingPrompt.max_length <= worker.max_length, @@ -91,6 +92,10 @@ def get_sorted_text_wp_filtered_to_worker(worker, models_list=None, priority_use worker.maintenance == False, # noqa E712 TextWaitingPrompt.user_id == worker.user_id, ), + or_( + is_backed_validated(worker.bridge_agent), + TextWaitingPrompt.validated_backends.is_(False), + ), ) ) if priority_user_ids: diff --git a/horde/database/threads.py b/horde/database/threads.py index 5c78c1b9..b01c8fde 100644 --- a/horde/database/threads.py +++ b/horde/database/threads.py @@ -9,7 +9,6 @@ import patreon from sqlalchemy import func, or_ -from horde import horde_redis as hr from horde.argparser import args from horde.classes.base.user import User from horde.classes.kobold.processing_generation import TextProcessingGeneration @@ -30,6 +29,7 @@ ) from horde.enums import State from horde.flask import HORDE, SQLITE_MODE, db +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.patreon import patrons from horde.r2 import delete_source_image @@ -204,17 +204,18 @@ def check_waiting_prompts(): .filter( procgen_class.generation == None, # noqa E712 procgen_class.faulted == False, # noqa E712 - # cutoff_time - procgen_class.start_time > wp_class.job_ttl, - # How do we calculate this in the query? 
Maybe I need to - # set an expiry time iun procgen as well better? + # TODO: How do we calculate this in the query? + # cutoff_time - procgen_class.start_time > procgen_class.job_ttl, ) .all() ) + modified_procgens = 0 for proc_gen in all_proc_gen: - if proc_gen.is_stale(proc_gen.wp.job_ttl): + if proc_gen.is_stale(): proc_gen.abort() proc_gen.wp.n += 1 - if len(all_proc_gen) >= 1: + modified_procgens += 1 + if modified_procgens >= 1: db.session.commit() # Faults WP with 3 or more faulted Procgens wp_ids = ( diff --git a/horde/detection.py b/horde/detection.py index 5f78e971..f9025469 100644 --- a/horde/detection.py +++ b/horde/detection.py @@ -13,7 +13,7 @@ from horde.argparser import args from horde.database.functions import compile_regex_filter, retrieve_regex_replacements from horde.flask import HORDE, SQLITE_MODE # Local Testing -from horde.horde_redis import horde_r_get +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.model_reference import model_reference @@ -82,7 +82,7 @@ def refresh_regex(self): with HORDE.app_context(): stored_replacements = retrieve_regex_replacements(filter_type=10) else: - cached_replacements = horde_r_get("cached_regex_replacements") + cached_replacements = hr.horde_r_get("cached_regex_replacements") if not cached_replacements: logger.warning("No cached regex replacements found in redis! 
Check threads!") stored_replacements = [] @@ -97,7 +97,7 @@ def refresh_regex(self): with HORDE.app_context(): stored_filter = compile_regex_filter(_id) else: - stored_filter = horde_r_get(filter_id) + stored_filter = hr.horde_r_get(filter_id) # Ensure we don't get catch-all regex if not stored_filter: continue diff --git a/horde/enums.py b/horde/enums.py index ba14f09d..914870b2 100644 --- a/horde/enums.py +++ b/horde/enums.py @@ -26,6 +26,7 @@ class UserRecordTypes(enum.Enum): USAGE = 1 FULFILLMENT = 3 REQUEST = 4 + STYLE = 5 class UserRoleTypes(enum.Enum): diff --git a/horde/exceptions.py b/horde/exceptions.py index 0f107dae..a95bcc9d 100644 --- a/horde/exceptions.py +++ b/horde/exceptions.py @@ -150,6 +150,14 @@ "InvalidTransparencyModel", "InvalidTransparencyImg2Img", "InvalidTransparencyCN", + "HiResMismatch", + "StylesAnonForbidden", + "StylePromptMissingVars", + "StylesRequiresCustomizer", + "StyleMismatch", + "StyleGetMistmatch", + "TooManyStyleExamples", + "ExampleURLAlreadyInUse", ] @@ -216,9 +224,9 @@ def __init__(self, username, rc="InvalidSize"): class InvalidPromptSize(wze.BadRequest): - def __init__(self, username, rc="InvalidPromptSize"): + def __init__(self, rc="InvalidPromptSize"): self.specific = "Too large prompt. Please reduce the amount of tokens contained." - self.log = f"User '{username}' sent an invalid size. Aborting!" 
+ self.log = None self.rc = rc diff --git a/horde/flask.py b/horde/flask.py index fd975323..c8efa737 100644 --- a/horde/flask.py +++ b/horde/flask.py @@ -13,32 +13,39 @@ from horde.redis_ctrl import ger_cache_url, is_redis_up cache = None -HORDE = Flask(__name__) -HORDE.config.SWAGGER_UI_DOC_EXPANSION = "list" -HORDE.wsgi_app = ProxyFix(HORDE.wsgi_app, x_for=1) - SQLITE_MODE = os.getenv("USE_SQLITE", "0") == "1" -if SQLITE_MODE: - logger.warning("Using SQLite for database") - HORDE.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///horde.db" -else: - HORDE.config["SQLALCHEMY_DATABASE_URI"] = ( - f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:" f"{os.getenv('POSTGRES_PASS')}@{os.getenv('POSTGRES_URL')}" - ) - HORDE.config["SQLALCHEMY_ENGINE_OPTIONS"] = { - "pool_size": 50, - "max_overflow": -1, - # "pool_pre_ping": True, - } -HORDE.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False -db = SQLAlchemy(HORDE) -db.init_app(HORDE) - -if not SQLITE_MODE: - with HORDE.app_context(): - logger.warning(f"pool size = {db.engine.pool.size()}") -logger.init_ok("Horde Database", status="Started") + +def create_app(): + HORDE = Flask(__name__) + HORDE.config.SWAGGER_UI_DOC_EXPANSION = "list" + HORDE.wsgi_app = ProxyFix(HORDE.wsgi_app, x_for=1) + + if SQLITE_MODE: + logger.warning("Using SQLite for database") + HORDE.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///horde.db" + else: + HORDE.config["SQLALCHEMY_DATABASE_URI"] = ( + f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:" f"{os.getenv('POSTGRES_PASS')}@{os.getenv('POSTGRES_URL')}" + ) + HORDE.config["SQLALCHEMY_ENGINE_OPTIONS"] = { + "pool_size": 50, + "max_overflow": -1, + # "pool_pre_ping": True, + } + HORDE.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + db.init_app(HORDE) + + if not SQLITE_MODE: + with HORDE.app_context(): + logger.warning(f"pool size = {db.engine.pool.size()}") + logger.init_ok("Horde Database", status="Started") + + return HORDE + + +db = SQLAlchemy() +HORDE = create_app() if is_redis_up(): 
try: diff --git a/horde/horde_redis.py b/horde/horde_redis.py index 636ede9c..3d1ddf3a 100644 --- a/horde/horde_redis.py +++ b/horde/horde_redis.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: AGPL-3.0-or-later import json +import threading +import time from datetime import timedelta from threading import Lock @@ -15,121 +17,125 @@ is_redis_up, ) -locks = {} - -horde_r = None -all_horde_redis = [] -logger.init("Horde Redis", status="Connecting") -if is_redis_up(): - horde_r = get_horde_db() - all_horde_redis = get_all_redis_db_servers() - logger.init_ok("Horde Redis", status="Connected") -else: - logger.init_err("Horde Redis", status="Failed") - - -horde_local_r = None -logger.init("Horde Local Redis", status="Connecting") -if is_local_redis_up(): - horde_local_r = get_local_horde_db() - logger.init_ok("Horde Local Redis", status="Connected") -else: - logger.init_err("Horde Local Redis", status="Failed") - - -def horde_r_set(key, value): - for hr in all_horde_redis: - try: - hr.set(key, value) - except Exception as err: - logger.warning(f"Exception when writing in redis servers {hr}: {err}") - if horde_local_r: - horde_local_r.setex(key, timedelta(10), value) - - -def horde_r_setex(key, expiry, value): - for hr in all_horde_redis: - try: - hr.setex(key, expiry, value) - except Exception as err: - logger.warning(f"Exception when writing in redis servers {hr}: {err}") - # We don't keep local cache for more than 5 seconds - if expiry > timedelta(5): - expiry = timedelta(5) - if horde_local_r: - horde_local_r.setex(key, expiry, value) - - -def horde_r_setex_json(key, expiry, value): - """Same as horde_r_setex() - but also converts the python builtin value to json - """ - horde_r_setex(key, expiry, json.dumps(value)) - - -def horde_r_local_set_to_json(key, value): - if horde_local_r: - if key not in locks: - locks[key] = Lock() - locks[key].acquire() - try: - horde_local_r.set(key, json.dumps(value)) - except Exception as err: - logger.error(f"Something went wrong when 
setting local redis: {err}") - locks[key].release() - - -def horde_local_setex_to_json(key, seconds, value): - if horde_local_r: - if key not in locks: - locks[key] = Lock() - locks[key].acquire() - try: - horde_local_r.setex(key, timedelta(seconds=seconds), json.dumps(value)) - except Exception as err: - logger.error(f"Something went wrong when setting local redis: {err}") - locks[key].release() - - -def horde_r_get(key): - """Retrieves the value from local redis if it exists - If it doesn't exist retrieves it from remote redis - If it exists in remote redis, also stores it in local redis - """ - value = None - if horde_local_r: - # if key in ["worker_cache","worker_cache_privileged"]: - # logger.warning(f"Got {key} from Local") - value = horde_local_r.get(key) - if value is None and horde_r: - value = horde_r.get(key) - if value is not None and horde_local_r is not None: - ttl = horde_r.ttl(key) - if ttl > 5: - ttl = 5 - if ttl <= 0: - ttl = 2 - # The local redis cache is always very temporary - if value is not None: - horde_local_r.setex(key, timedelta(seconds=abs(ttl)), value) - return value - - -def horde_r_get_json(key): - """Same as horde_r_get() - but also converts the json to python built-ins - """ - value = horde_r_get(key) - if value is None: - return None - return json.loads(value) - - -def horde_r_delete(key): - for hr in all_horde_redis: - try: - hr.delete(key) - except Exception as err: - logger.warning(f"Exception when deleting from redis servers {hr}: {err}") - if horde_local_r: - horde_local_r.delete(key) + +class HordeRedis: + locks = {} + horde_r = None + all_horde_redis = [] + horde_local_r = None + check_redis_thread = None + + def __init__(self): + logger.init("Horde Redis", status="Connecting") + if is_redis_up(): + self.horde_r = get_horde_db() + self.all_horde_redis = get_all_redis_db_servers() + logger.init_ok("Horde Redis", status="Connected") + else: + logger.init_err("Horde Redis", status="Failed") + logger.init("Horde Local Redis", 
status="Connecting") + if is_local_redis_up(): + self.horde_local_r = get_local_horde_db() + logger.init_ok("Horde Local Redis", status="Connected") + else: + logger.init_err("Horde Local Redis", status="Failed") + self.check_redis_thread = threading.Thread(target=self.check_redis_backends, args=(), daemon=True) + self.check_redis_thread.start() + + def check_redis_backends(self): + while True: + time.sleep(10) + self.all_horde_redis = get_all_redis_db_servers() + + def horde_r_set(self, key, value): + for hr in self.all_horde_redis: + try: + hr.set(key, value) + except Exception as err: + logger.warning(f"Exception when writing in redis servers {hr}: {err}") + if self.horde_local_r: + self.horde_local_r.setex(key, timedelta(10), value) + + def horde_r_setex(self, key, expiry, value): + for hr in self.all_horde_redis: + try: + hr.setex(key, expiry, value) + except Exception as err: + logger.warning(f"Exception when writing in redis servers {hr}: {err}") + # We don't keep local cache for more than 5 seconds + if expiry > timedelta(5): + expiry = timedelta(5) + if self.horde_local_r: + self.horde_local_r.setex(key, expiry, value) + + def horde_r_setex_json(self, key, expiry, value): + """Same as horde_r_setex() + but also converts the python builtin value to json + """ + self.horde_r_setex(key, expiry, json.dumps(value)) + + def horde_r_local_set_to_json(self, key, value): + if self.horde_local_r: + if key not in self.locks: + self.locks[key] = Lock() + self.locks[key].acquire() + try: + self.horde_local_r.set(key, json.dumps(value)) + except Exception as err: + logger.error(f"Something went wrong when setting local redis: {err}") + self.locks[key].release() + + def horde_local_setex_to_json(self, key, seconds, value): + if self.horde_local_r: + if key not in self.locks: + self.locks[key] = Lock() + self.locks[key].acquire() + try: + self.horde_local_r.setex(key, timedelta(seconds=seconds), json.dumps(value)) + except Exception as err: + logger.error(f"Something went 
wrong when setting local redis: {err}") + self.locks[key].release() + + def horde_r_get(self, key): + """Retrieves the value from local redis if it exists + If it doesn't exist retrieves it from remote redis + If it exists in remote redis, also stores it in local redis + """ + value = None + if self.horde_local_r: + # if key in ["worker_cache","worker_cache_privileged"]: + # logger.warning(f"Got {key} from Local") + value = self.horde_local_r.get(key) + if value is None and self.horde_r: + value = self.horde_r.get(key) + if value is not None and self.horde_local_r is not None: + ttl = self.horde_r.ttl(key) + if ttl > 5: + ttl = 5 + if ttl <= 0: + ttl = 2 + # The local redis cache is always very temporary + if value is not None: + self.horde_local_r.setex(key, timedelta(seconds=abs(ttl)), value) + return value + + def horde_r_get_json(self, key): + """Same as horde_r_get() + but also converts the json to python built-ins + """ + value = self.horde_r_get(key) + if value is None: + return None + return json.loads(value) + + def horde_r_delete(self, key): + for hr in self.all_horde_redis: + try: + hr.delete(key) + except Exception as err: + logger.warning(f"Exception when deleting from redis servers {hr}: {err}") + if self.horde_local_r: + self.horde_local_r.delete(key) + + +horde_redis = HordeRedis() diff --git a/horde/model_reference.py b/horde/model_reference.py index 79604913..1f0a4279 100644 --- a/horde/model_reference.py +++ b/horde/model_reference.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later import os +from datetime import datetime import requests @@ -24,15 +25,21 @@ class ModelReference(PrimaryTimedFunction): testing_models = {} def call_function(self): - """Retrieves to nataili and text model reference and stores in it a var""" + """Retrieves to image and text model reference and stores in it a var""" # If it's running in SQLITE_MODE, it means it's a test and we never want to grab the quorum # We don't want to report on any random model name 
a client might request for _riter in range(10): try: + ref_json = "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json" + if datetime.utcnow() <= datetime(2024, 9, 30): # Flux Beta + ref_json = ( + "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/refs/heads/flux/stable_diffusion.json" + ) + logger.debug("Using flux beta model reference...") self.reference = requests.get( os.getenv( "HORDE_IMAGE_COMPVIS_REFERENCE", - "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json", + ref_json, ), timeout=2, ).json() @@ -53,6 +60,7 @@ def call_function(self): "stable diffusion 2 512", "stable_diffusion_xl", "stable_cascade", + "flux_1", }: self.stable_diffusion_names.add(model) if self.reference[model].get("nsfw"): diff --git a/horde/patreon.py b/horde/patreon.py index aa6ab02d..cc93f770 100644 --- a/horde/patreon.py +++ b/horde/patreon.py @@ -4,7 +4,7 @@ import json -from horde import horde_redis as hr +from horde.horde_redis import horde_redis as hr from horde.logger import logger from horde.threads import PrimaryTimedFunction diff --git a/horde/redis_ctrl.py b/horde/redis_ctrl.py index d9f1e5f6..e684dcd0 100644 --- a/horde/redis_ctrl.py +++ b/horde/redis_ctrl.py @@ -22,9 +22,17 @@ ipaddr_timeout_db = 5 -def is_redis_up() -> bool: +def is_redis_up(hostname=redis_hostname, port=redis_port) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex((redis_hostname, redis_port)) == 0 + s.settimeout(3) + try: + return s.connect_ex((hostname, port)) == 0 + except socket.gaierror as e: + # connect_ex suppresses exceptions from POSIX connect() call + # but can still raise gaierror if e.g. the hostname is invalid. + # This may be transient, so log the error and return False. 
+ logger.error(f"Redis server at {hostname}:{port} is not reachable: {e}") + return False def is_local_redis_up() -> bool: @@ -70,7 +78,13 @@ def get_all_redis_db_servers(): This allows redis to transparently failover. """ try: - return [get_redis_db_server(rs) for rs in json.loads(os.getenv("REDIS_SERVERS"))] + working_redis = [] + for rs in json.loads(os.getenv("REDIS_SERVERS")): + if is_redis_up(rs): + working_redis.append(get_redis_db_server(rs)) + else: + logger.warning(f"redis server '{rs}' appears unreachable. Will not be used in the cluster") + return working_redis except Exception: logger.error("Error setting up REDIS_SERVERS array. Falling back to loadbalancer.") return [get_horde_db()] diff --git a/horde/routes.py b/horde/routes.py index 380ba027..91d0bcd2 100644 --- a/horde/routes.py +++ b/horde/routes.py @@ -20,6 +20,7 @@ from horde.classes.base import settings from horde.classes.base.news import News from horde.classes.base.user import User +from horde.consts import HORDE_API_VERSION, HORDE_VERSION from horde.countermeasures import CounterMeasures from horde.database import functions as database from horde.flask import HORDE, cache, db @@ -29,6 +30,8 @@ from horde.vars import ( google_verification_string, horde_contact_email, + horde_logo, + horde_repository, horde_title, horde_url, img_url, @@ -391,3 +394,26 @@ def terms(): @HORDE.route("/assets/") def assets(filename): return send_from_directory("../assets", filename) + + +@HORDE.route("/.well-known/serviceinfo") +def serviceinfo(): + return { + "version": "0.2", + "software": { + "name": horde_title, + "version": HORDE_VERSION, + "repository": horde_repository, + "homepage": horde_url, + "logo": horde_logo, + }, + "api": { + "aihorde": { + "name": "AI Horde API", + "version": HORDE_API_VERSION, + "base_url": f"{horde_url}/api/v2", + "rel_url": "/api/v2", + "documentation": f"{horde_url}/api", + }, + }, + }, 200 diff --git a/horde/sandbox.py b/horde/sandbox.py index ec19e306..3e711a81 100644 --- 
a/horde/sandbox.py +++ b/horde/sandbox.py @@ -6,13 +6,14 @@ import sys import horde.classes.base.stats as stats +from horde.classes.base.style import Style, StyleCollection, StyleModel, StyleTag from horde.classes.stable.worker import ImageWorker from horde.countermeasures import CounterMeasures from horde.database import functions as database from horde.database import threads as threads from horde.detection import prompt_checker from horde.discord import send_pause_notification, send_problem_user_notification -from horde.flask import HORDE +from horde.flask import HORDE, db from horde.logger import logger from horde.model_reference import model_reference from horde.patreon import patrons @@ -60,4 +61,9 @@ def test(): # logger.debug(w.max_context_length) # logger.debug(w.calculate_uptime_reward()) pass + + with HORDE.app_context(): + logger.debug(db.session.query(StyleCollection).offset(0).limit(25).all()) + logger.debug(database.retrieve_available_collections()) + sys.exit() diff --git a/horde/utils.py b/horde/utils.py index e4389b75..8e69840e 100644 --- a/horde/utils.py +++ b/horde/utils.py @@ -16,6 +16,7 @@ from better_profanity import profanity from profanity_check import predict +from horde import exceptions as e from horde.flask import SQLITE_MODE profanity.load_censor_words() @@ -101,6 +102,10 @@ def get_expiry_date(): return datetime.utcnow() + dateutil.relativedelta.relativedelta(minutes=+20) +def get_extra_slow_expiry_date(): + return datetime.utcnow() + dateutil.relativedelta.relativedelta(minutes=+60) + + def get_interrogation_form_expiry_date(): return datetime.utcnow() + dateutil.relativedelta.relativedelta(minutes=+3) @@ -135,3 +140,9 @@ def does_extra_text_reference_exist(extra_texts, reference): if et["reference"] == reference: return True return False + + +def ensure_clean(string, key): + if is_profane(string): + raise e.BadRequest(f"{key} contains profanity") + return sanitize_string(string) diff --git a/horde/validation.py b/horde/validation.py 
new file mode 100644 index 00000000..9492986c --- /dev/null +++ b/horde/validation.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: 2024 Konstantinos Thoukydidis +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +from loguru import logger + +from horde import exceptions as e +from horde.classes.base.user import User +from horde.consts import KNOWN_POST_PROCESSORS, KNOWN_UPSCALERS +from horde.enums import WarningMessage +from horde.model_reference import model_reference + + +class ParamValidator: + + prompt: str + models: list + params: dict + user: User + warnings = set() + + def __init__(self, prompt, models, params, user): + self.prompt = prompt + self.models = models + self.params = params + self.user = user + self.warnings = set() + + def validate_base_params(self): + pass + + def validate_text_params(self): + self.validate_base_params() + if self.params.get("max_context_length", 1024) < self.params.get("max_length", 80): + raise e.BadRequest("You cannot request more tokens than your context length.", rc="TokenOverflow") + if "sampler_order" in self.params and len(set(self.params["sampler_order"])) < 7: + raise e.BadRequest( + "When sending a custom sampler order, you need to specify all possible samplers in the order", + rc="MissingFullSamplerOrder", + ) + if "stop_sequence" in self.params: + stop_seqs = set(self.params["stop_sequence"]) + if len(stop_seqs) > 128: + raise e.BadRequest("Too many stop sequences specified (max allowed is 128).", rc="TooManyStopSequences") + total_stop_seq_len = 0 + for seq in stop_seqs: + total_stop_seq_len += len(seq) + if total_stop_seq_len > 2000: + raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).", rc="ExcessiveStopSequence") + + def validate_image_params(self): + self.validate_base_params() + for model_req_dict in [model_reference.get_model_requirements(m) for m in self.models]: + if "clip_skip" in model_req_dict and model_req_dict["clip_skip"] != self.params.get("clip_skip", 1): + 
self.warnings.add(WarningMessage.ClipSkipMismatch) + if "min_steps" in model_req_dict and model_req_dict["min_steps"] > self.params.get("steps", 30): + self.warnings.add(WarningMessage.StepsTooFew) + if "max_steps" in model_req_dict and model_req_dict["max_steps"] < self.params.get("steps", 30): + self.warnings.add(WarningMessage.StepsTooMany) + if "cfg_scale" in model_req_dict and model_req_dict["cfg_scale"] != self.params.get("cfg_scale", 7.5): + self.warnings.add(WarningMessage.CfgScaleMismatch) + if "min_cfg_scale" in model_req_dict and model_req_dict["min_cfg_scale"] > self.params.get("cfg_scale", 7.5): + self.warnings.add(WarningMessage.CfgScaleTooSmall) + if "max_cfg_scale" in model_req_dict and model_req_dict["max_cfg_scale"] < self.params.get("cfg_scale", 7.5): + self.warnings.add(WarningMessage.CfgScaleTooLarge) + if "samplers" in model_req_dict and self.params.get("sampler_name", "k_euler_a") not in model_req_dict["samplers"]: + self.warnings.add(WarningMessage.SamplerMismatch) + # FIXME: Scheduler workaround until we support multiple schedulers + scheduler = "karras" + if not self.params.get("karras", True): + scheduler = "simple" + if "schedulers" in model_req_dict and scheduler not in model_req_dict["schedulers"]: + self.warnings.add(WarningMessage.SchedulerMismatch) + if any(model_reference.get_model_baseline(model_name).startswith("flux_1") for model_name in self.models): + if self.params.get("hires_fix", False) is True: + raise e.BadRequest("HiRes Fix does not work with Flux currently.", rc="HiResMismatch") + if "loras" in self.params: + if len(self.params["loras"]) > 5: + raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras") + for lora in self.params["loras"]: + if lora.get("is_version") and not lora["name"].isdigit(): + raise e.BadRequest("explicit LoRa version requests have to be a version ID (i.e integer).", rc="BadLoraVersion") + if "tis" in self.params and len(self.params["tis"]) > 20: + raise 
e.BadRequest("You cannot request more than 20 Textual Inversions per generation.", rc="TooManyTIs") + if self.params.get("transparent", False) is True: + if any( + model_reference.get_model_baseline(model_name) not in ["stable_diffusion_xl", "stable diffusion 1"] + for model_name in self.models + ): + raise e.BadRequest( + "Generating Transparent images is only possible for Stable Diffusion 1.5 and XL models.", + rc="InvalidTransparencyModel", + ) + if self.params.get("workflow") == "qr_code": + if not all( + model_reference.get_model_baseline(model_name) in ["stable diffusion 1", "stable_diffusion_xl"] + for model_name in self.models + ): + raise e.BadRequest("QR Code controlnet only works with SD 1.5 and SDXL models currently", rc="ControlNetMismatch.") + if len(self.prompt.split()) > 7500: + raise e.InvalidPromptSize() + if any(model_name in KNOWN_POST_PROCESSORS for model_name in self.models): + raise e.UnsupportedModel(rc="UnexpectedModelName") + upscaler_count = len([pp for pp in self.params.get("post_processing", []) if pp in KNOWN_UPSCALERS]) + if upscaler_count > 1: + raise e.BadRequest("Cannot use more than 1 upscaler at a time.", rc="TooManyUpscalers") + cfg_scale = self.params.get("cfg_scale") + if cfg_scale is not None: + try: + rounded_cfg_scale = round(cfg_scale, 2) + if rounded_cfg_scale != cfg_scale: + raise e.BadRequest("cfg_scale must be rounded to 2 decimal places", rc="BadCFGDecimals") + except (TypeError, ValueError): + logger.warning( + f"Invalid cfg_scale: {cfg_scale} for when it should be already validated.", + ) + raise e.BadRequest("cfg_scale must be a valid number", rc="BadCFGNumber") + + return self.warnings + + def check_for_special(self): + if not self.user and self.params.get("special"): + raise e.BadRequest("Only special users can send a special field.", "SpecialFieldNeedsSpecialUser") + for model in self.models: + if "horde_special" in model: + if not self.user.special: + raise e.Forbidden("Only special users can request a special 
model.", "SpecialModelNeedsSpecialUser") + usermodel = model.split("::") + if len(usermodel) == 1: + raise e.BadRequest( + "Special models must always include the username, in the form of 'horde_special::user#id'", + rc="SpecialMissingUsername", + ) + user_alias = usermodel[1] + if self.user.get_unique_alias() != user_alias: + raise e.Forbidden(f"This model can only be requested by {user_alias}", "SpecialForbidden") + if not self.params.get("special"): + raise e.BadRequest("Special models have to include a special payload", rc="SpecialMissingPayload") + + def validate_image_prompt(self, prompt): + if "{p}" not in prompt: + raise e.BadRequest( + "A style prompt must include a dedicated spot where the user's positive prompt will be added, signified with '{p}'", + "StylePromptMissingVars", + ) + if "{np}" not in prompt: + raise e.BadRequest( + "A style prompt must include a dedicated spot where the user's negative prompt will be added, signified with '{np}'", + "StylePromptMissingVars", + ) + + def validate_text_prompt(self, prompt): + if "{p}" not in prompt: + raise e.BadRequest( + "A style prompt must include a dedicated spot where the user's positive prompt will be added, signified with '{p}'", + "StylePromptMissingVars", + ) diff --git a/horde/vars.py b/horde/vars.py index f0a1ae15..f5013f15 100644 --- a/horde/vars.py +++ b/horde/vars.py @@ -42,5 +42,7 @@ horde_title = os.getenv("HORDE_TITLE", "AI Horde") horde_noun = os.getenv("HORDE_noun", "horde") horde_url = os.getenv("HORDE_URL", "https://aihorde.net") +horde_repository = os.getenv("HORDE_REPOSITORY", "https://github.com/Haidra-Org/AI-Horde") +horde_logo = os.getenv("HORDE_LOGO", "https://aihorde.net/assets/img/logo.png") horde_contact_email = os.getenv("HORDE_EMAIL", "aihorde@dbzer0.com") horde_instance_id = str(uuid4()) diff --git a/img_stable/1.jpg b/img_stable/1.jpg new file mode 100644 index 00000000..f939b4f1 Binary files /dev/null and b/img_stable/1.jpg differ diff --git a/img_stable/1.jpg.license 
b/img_stable/1.jpg.license new file mode 100644 index 00000000..5f7f9564 --- /dev/null +++ b/img_stable/1.jpg.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis + +SPDX-License-Identifier: CC0-1.0 diff --git a/img_stable/2.jpg b/img_stable/2.jpg new file mode 100644 index 00000000..f939b4f1 Binary files /dev/null and b/img_stable/2.jpg differ diff --git a/img_stable/2.jpg.license b/img_stable/2.jpg.license new file mode 100644 index 00000000..5f7f9564 --- /dev/null +++ b/img_stable/2.jpg.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis + +SPDX-License-Identifier: CC0-1.0 diff --git a/img_stable/3.jpg b/img_stable/3.jpg new file mode 100644 index 00000000..f939b4f1 Binary files /dev/null and b/img_stable/3.jpg differ diff --git a/img_stable/3.jpg.license b/img_stable/3.jpg.license new file mode 100644 index 00000000..5f7f9564 --- /dev/null +++ b/img_stable/3.jpg.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis + +SPDX-License-Identifier: CC0-1.0 diff --git a/requirements.txt b/requirements.txt index 3e799de4..dacef9c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,9 +20,9 @@ better_profanity python-dateutil~=2.8.2 redis~=4.3.5 pillow>=10.0.1 -flask_sqlalchemy==3.0.2 +flask_sqlalchemy~=3.1.1 oauthlib~=3.2.2 -SQLAlchemy~=1.4.44 +SQLAlchemy~=2.0.35 psycopg2-binary boto3>=1.33.7 regex @@ -33,4 +33,4 @@ torch emoji semver >= 3.0.2 numpy ~= 1.26.4 # better_profanity fails on later versions of numpy -markdownify +markdownify diff --git a/sql_statements/4.41.0.txt b/sql_statements/4.41.0.txt new file mode 100644 index 00000000..c77e8819 --- /dev/null +++ b/sql_statements/4.41.0.txt @@ -0,0 +1 @@ +ALTER TABLE waiting_prompts ADD COLUMN validated_backends BOOLEAN default true; diff --git a/sql_statements/4.41.0.txt.license b/sql_statements/4.41.0.txt.license new file mode 100644 index 00000000..8140c6e2 --- /dev/null +++ b/sql_statements/4.41.0.txt.license @@ -0,0 +1,3 @@ 
+SPDX-FileCopyrightText: Konstantinos Thoukydidis + +SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/sql_statements/4.43.0.txt b/sql_statements/4.43.0.txt new file mode 100644 index 00000000..9fb608d7 --- /dev/null +++ b/sql_statements/4.43.0.txt @@ -0,0 +1,13 @@ +ALTER TABLE waiting_prompts ADD COLUMN extra_slow_workers BOOLEAN NOT NULL default false; +ALTER TABLE workers ADD COLUMN extra_slow_worker BOOLEAN NOT NULL default false; +ALTER TABLE workers ADD COLUMN limit_max_steps BOOLEAN NOT NULL default false; +ALTER TABLE processing_gens ADD COLUMN job_ttl INTEGER NOT NULL DEFAULT 150; +CREATE INDEX idx_processing_gens_job_ttl ON public.processing_gens USING btree(job_ttl); +CREATE INDEX idx_waiting_prompts_extra_slow_workers ON public.waiting_prompts USING btree(extra_slow_workers); +CREATE INDEX idx_workers_extra_slow_worker ON public.workers USING btree(extra_slow_worker); +CREATE INDEX idx_workers_allow_img2img ON public.workers USING btree(allow_img2img); +CREATE INDEX idx_workers_allow_painting ON public.workers USING btree(allow_painting); +CREATE INDEX idx_workers_allow_post_processing ON public.workers USING btree(allow_post_processing); +CREATE INDEX idx_workers_allow_controlnet ON public.workers USING btree(allow_controlnet); +CREATE INDEX idx_workers_allow_sdxl_controlnet ON public.workers USING btree(allow_sdxl_controlnet); +CREATE INDEX idx_workers_allow_lora ON public.workers USING btree(allow_lora); diff --git a/sql_statements/4.43.0.txt.license b/sql_statements/4.43.0.txt.license new file mode 100644 index 00000000..8140c6e2 --- /dev/null +++ b/sql_statements/4.43.0.txt.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Konstantinos Thoukydidis + +SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/sql_statements/4.44.0.txt b/sql_statements/4.44.0.txt new file mode 100644 index 00000000..8d8c8dd1 --- /dev/null +++ b/sql_statements/4.44.0.txt @@ -0,0 +1,2 @@ +ALTER TYPE userrecordtypes ADD VALUE 'STYLE'; +alter sequence 
interrogation_worker_forms_id_seq cycle; diff --git a/sql_statements/4.44.0.txt.license b/sql_statements/4.44.0.txt.license new file mode 100644 index 00000000..8140c6e2 --- /dev/null +++ b/sql_statements/4.44.0.txt.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Konstantinos Thoukydidis + +SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/sql_statements/README.md b/sql_statements/README.md index 04ae9af8..efa8b143 100644 --- a/sql_statements/README.md +++ b/sql_statements/README.md @@ -22,4 +22,4 @@ SPDX-License-Identifier: AGPL-3.0-or-later - `compile_*gen_stats_*.sql` - These files defined stored procedures which populated the `compiled_*` tables and generally represent minute/hour/day/total statistics about generations. - `cron_jobs/` - - Schedules any stats compile jobs via `schedule_cron_job`. \ No newline at end of file + - Schedules any stats compile jobs via `schedule_cron_job`. diff --git a/sql_statements/stored_procedures/compile_imagegen_stats_models.sql b/sql_statements/stored_procedures/compile_imagegen_stats_models.sql index a4e52933..c662b6ae 100644 --- a/sql_statements/stored_procedures/compile_imagegen_stats_models.sql +++ b/sql_statements/stored_procedures/compile_imagegen_stats_models.sql @@ -1,4 +1,3 @@ --- SPDX-FileCopyrightText: 2024 2022 Tazlin -- SPDX-FileCopyrightText: 2024 Tazlin -- -- SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql b/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql index 0dd7b9b5..6a379e49 100644 --- a/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql +++ b/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql @@ -1,4 +1,3 @@ --- SPDX-FileCopyrightText: 2024 2022 Tazlin -- SPDX-FileCopyrightText: 2024 Tazlin -- -- SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/sql_statements/stored_procedures/compile_textgen_stats_models.sql 
b/sql_statements/stored_procedures/compile_textgen_stats_models.sql index f6992cab..7bb5d867 100644 --- a/sql_statements/stored_procedures/compile_textgen_stats_models.sql +++ b/sql_statements/stored_procedures/compile_textgen_stats_models.sql @@ -1,4 +1,3 @@ --- SPDX-FileCopyrightText: 2024 2022 Tazlin -- SPDX-FileCopyrightText: 2024 Tazlin -- -- SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/tests/test_image.py b/tests/test_image.py index 109d0ca7..48f35167 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later -import json + +import time import requests @@ -10,6 +11,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: + print("test_simple_image_gen") headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user async_dict = { "prompt": "a horde of cute stable robots in a sprawling server room repairing a massive mainframe", @@ -37,11 +39,10 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: async_results = async_req.json() req_id = async_results["id"] # print(async_results) - print(async_results) pop_dict = { "name": "CICD Fake Dreamer", "models": TEST_MODELS, - "bridge_agent": "AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen", + "bridge_agent": "AI Horde Worker reGen:9.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen", "nsfw": True, "amount": 10, "max_pixels": 4194304, @@ -55,14 +56,14 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: } pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) try: - print(pop_req.text) + # print(pop_req.text) assert pop_req.ok, pop_req.text except AssertionError as err: requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) print("Request cancelled") raise err pop_results = 
pop_req.json() - print(json.dumps(pop_results, indent=4)) + # print(json.dumps(pop_results, indent=4)) job_id = pop_results["id"] try: @@ -84,7 +85,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) assert retrieve_req.ok, retrieve_req.text retrieve_results = retrieve_req.json() - print(json.dumps(retrieve_results, indent=4)) + # print(json.dumps(retrieve_results, indent=4)) assert len(retrieve_results["generations"]) == 1 gen = retrieve_results["generations"][0] assert len(gen["gen_metadata"]) == 0 @@ -94,8 +95,182 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: assert gen["state"] == "ok" assert retrieve_results["kudos"] > 1 assert retrieve_results["done"] is True + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + + +TEST_MODELS_FLUX = ["Flux.1-Schnell fp8 (Compact)"] + + +def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: + print("test_flux_image_gen") + headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user + async_dict = { + "prompt": "a horde of cute flux robots in a sprawling server room repairing a massive mainframe", + "nsfw": True, + "censor_nsfw": False, + "r2": True, + "shared": True, + "trusted_workers": True, + "params": { + "width": 1024, + "height": 1024, + "steps": 8, + "cfg_scale": 1, + "sampler_name": "k_euler", + }, + "models": TEST_MODELS_FLUX, + # "extra_slow_workers": True, + } + protocol = "http" + if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]: + protocol = "https" + time.sleep(1) + async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/async", json=async_dict, headers=headers) + assert async_req.ok, async_req.text + async_results = async_req.json() + req_id = async_results["id"] + # print(async_results) + 
pop_dict = { + "name": "CICD Fake Dreamer", + "models": TEST_MODELS_FLUX, + "bridge_agent": "AI Horde Worker reGen:9.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen", + "nsfw": True, + "amount": 10, + "max_pixels": 4194304, + "allow_img2img": True, + "allow_painting": True, + "allow_unsafe_ipaddr": True, + "allow_post_processing": True, + "allow_controlnet": True, + "allow_sdxl_controlnet": True, + "allow_lora": True, + "extra_slow_worker": False, + "limit_max_steps": True, + } + + # Test limit_max_steps + pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) + try: + # print(pop_req.text) + assert pop_req.ok, pop_req.text + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + pop_results = pop_req.json() + # print(json.dumps(pop_results, indent=4)) + try: + assert pop_results["id"] is None, pop_results + assert pop_results["skipped"].get("step_count") == 1, pop_results + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + + # Test extra_slow_worker + async_dict["params"]["steps"] = 5 + pop_dict["extra_slow_worker"] = True + time.sleep(0.5) + pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) + try: + # print(pop_req.text) + assert pop_req.ok, pop_req.text + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + pop_results = pop_req.json() + # print(json.dumps(pop_results, indent=4)) + try: + assert pop_results["id"] is None, pop_results + assert pop_results["skipped"]["performance"] == 1, pop_results + except AssertionError as err: + 
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + + # Try popping as an extra slow worker + async_dict["extra_slow_workers"] = True + time.sleep(0.5) + async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/async", json=async_dict, headers=headers) + assert async_req.ok, async_req.text + async_results = async_req.json() + req_id = async_results["id"] + time.sleep(0.5) + pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) + try: + assert pop_req.ok, pop_req.text + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + pop_results = pop_req.json() + job_id = pop_results["id"] + try: + assert job_id is not None, pop_results + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + submit_dict = { + "id": job_id, + "generation": "R2", + "state": "ok", + "seed": 0, + } + submit_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/submit", json=submit_dict, headers=headers) + assert submit_req.ok, submit_req.text + submit_results = submit_req.json() + assert submit_results["reward"] > 0 + retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + assert retrieve_req.ok, retrieve_req.text + retrieve_results = retrieve_req.json() + # print(json.dumps(retrieve_results, indent=4)) + assert len(retrieve_results["generations"]) == 1 + gen = retrieve_results["generations"][0] + assert len(gen["gen_metadata"]) == 0 + assert gen["seed"] == "0" + assert gen["worker_name"] == "CICD Fake Dreamer" + assert gen["model"] in TEST_MODELS_FLUX + assert 
gen["state"] == "ok" + assert retrieve_results["kudos"] > 1 + assert retrieve_results["done"] is True + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + + +def quick_pop(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: + print("quick_pop") + headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user + protocol = "http" + if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]: + protocol = "https" + # print(async_results) + pop_dict = { + "name": "CICD Fake Dreamer", + "models": TEST_MODELS_FLUX, + "bridge_agent": "AI Horde Worker reGen:9.1.0-citests:https://github.com/Haidra-Org/horde-worker-reGen", + "nsfw": True, + "amount": 10, + "max_pixels": 4194304, + "allow_img2img": True, + "allow_painting": True, + "allow_unsafe_ipaddr": True, + "allow_post_processing": True, + "allow_controlnet": True, + "allow_sdxl_controlnet": True, + "allow_lora": True, + "extra_slow_worker": False, + "limit_max_steps": True, + } + + # Test limit_max_steps + pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) + print(pop_req.text) if __name__ == "__main__": # "ci/cd#12285" - test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1") + test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0") + test_flux_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0") + # quick_pop("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0") diff --git a/tests/test_text.py b/tests/test_text.py index 4c99a527..98522f38 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -12,6 +12,7 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: async_dict = { "prompt": "a horde of cute stable robots in a sprawling server room repairing a massive mainframe", "trusted_workers": True, + "validated_backends": False, "max_length": 512, "max_context_length": 2048, 
"temperature": 1, @@ -37,9 +38,13 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: assert pop_req.ok, pop_req.text pop_results = pop_req.json() # print(json.dumps(pop_results, indent=4)) - job_id = pop_results["id"] - assert job_id is not None, pop_results + try: + assert job_id is not None, pop_results + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/text/status/{req_id}", headers=headers) + print("Request cancelled") + raise err submit_dict = { "id": job_id, "generation": "test ",