diff --git a/horde/apis/v2/kobold.py b/horde/apis/v2/kobold.py index 2c6031a1..a38b35ac 100644 --- a/horde/apis/v2/kobold.py +++ b/horde/apis/v2/kobold.py @@ -7,8 +7,8 @@ from horde.apis.v2.base import GenerateTemplate, JobPopTemplate, JobSubmitTemplate, api from horde.classes.base import settings from horde.classes.kobold.genstats import ( - compile_textgen_stats_models, - compile_textgen_stats_totals, + get_compiled_textgen_stats_models, + get_compiled_textgen_stats_totals, ) from horde.classes.kobold.waiting_prompt import TextWaitingPrompt from horde.classes.kobold.worker import TextWorker @@ -356,7 +356,7 @@ def get(self): """Details how many texts have been generated in the past minux,hour,day,month and total Also shows the amount of pixelsteps for the same timeframe. """ - return compile_textgen_stats_totals(), 200 + return get_compiled_textgen_stats_totals(), 200 class TextHordeStatsModels(Resource): @@ -380,7 +380,7 @@ class TextHordeStatsModels(Resource): ) def get(self): """Details how many texts were generated per model for the past day, month and total""" - return compile_textgen_stats_models(), 200 + return get_compiled_textgen_stats_models(), 200 class KoboldKudosTransfer(Resource): diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 6b441385..150128c2 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -12,8 +12,8 @@ from horde.classes.base import settings from horde.classes.base.user import User from horde.classes.stable.genstats import ( - compile_imagegen_stats_models, - compile_imagegen_stats_totals, + get_compiled_imagegen_stats_models, + get_compiled_imagegen_stats_totals, ) from horde.classes.stable.interrogation import Interrogation from horde.classes.stable.interrogation_worker import InterrogationWorker @@ -573,6 +573,7 @@ def post(self): if "blacklist" in post_ret.get("skipped", {}): db_skipped["blacklist"] = post_ret["skipped"]["blacklist"] post_ret["skipped"] = db_skipped + return post_ret, retcode def 
check_in(self): @@ -1272,7 +1273,7 @@ def get(self): """Details how many images have been generated in the past minux,hour,day,month and total Also shows the amount of pixelsteps for the same timeframe. """ - return compile_imagegen_stats_totals(), 200 + return get_compiled_imagegen_stats_totals(), 200 class ImageHordeStatsModels(Resource): @@ -1312,4 +1313,4 @@ def get(self): self.args = self.get_parser.parse_args() if self.args.model_state not in ["known", "custom", "all"]: raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']") - return compile_imagegen_stats_models(self.args.model_state), 200 + return get_compiled_imagegen_stats_models(self.args.model_state), 200 diff --git a/horde/classes/__init__.py b/horde/classes/__init__.py index 9c38f372..55f5da48 100644 --- a/horde/classes/__init__.py +++ b/horde/classes/__init__.py @@ -1,3 +1,5 @@ +from pathlib import Path + import horde.classes.base.stats # noqa 401 from horde.argparser import args from horde.classes.base.detection import Filter # noqa 401 @@ -10,13 +12,14 @@ from horde.classes.kobold.worker import TextWorker # noqa 401 from horde.classes.stable.interrogation import Interrogation # noqa 401 from horde.classes.stable.interrogation_worker import InterrogationWorker # noqa 401 +from horde.classes.stable.known_image_models import KnownImageModel # noqa 401 # Importing for DB creation - # noqa 401 from horde.classes.stable.waiting_prompt import ImageWaitingPrompt # noqa 401 from horde.classes.stable.worker import ImageWorker # noqa 401 from horde.flask import HORDE, db +from horde.logger import logger from horde.utils import hash_api_key with HORDE.app_context(): @@ -29,6 +32,29 @@ # sys.exit() db.create_all() + sql_statement_dir = Path(__file__).parent.parent.parent / "sql_statements" + + # The order of these directories is important. `cron` creates a stored procedure that is + # used by queries in all other `cron_jobs/` directories. 
+ all_dirs_to_run = [ + "cron/", # Must be first + "stored_procedures/", + "stored_procedures/cron_jobs/", + ] + + all_dirs_to_run = [sql_statement_dir / dir for dir in all_dirs_to_run] + + with logger.catch(reraise=True): + for dir in all_dirs_to_run: + logger.info(f"Running files in {dir}") + for file in dir.iterdir(): + if file.suffix == ".sql": + logger.info(f"Running {file}") + with file.open() as f: + db.session.execute(f.read()) + + db.session.commit() + if args.convert_flag == "roles": # from horde.conversions import convert_user_roles @@ -65,6 +91,7 @@ "TextWaitingPrompt", "Interrogation", "InterrogationWorker", + "KnownImageModel", "User", "Team", "ImageWorker", diff --git a/horde/classes/base/user.py b/horde/classes/base/user.py index 062b3a80..a11a8172 100644 --- a/horde/classes/base/user.py +++ b/horde/classes/base/user.py @@ -230,7 +230,7 @@ class User(db.Model): oauth_id = db.Column(db.String(50), unique=True, nullable=False, index=True) api_key = db.Column(db.String(100), unique=True, nullable=False, index=True) client_id = db.Column(db.String(50), unique=True, default=generate_client_id, nullable=False) - created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True) last_active = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) contact = db.Column(db.String(50), default=None) admin_comment = db.Column(db.Text, default=None) diff --git a/horde/classes/kobold/genstats.py b/horde/classes/kobold/genstats.py index e85adaa8..f6fc9b6e 100644 --- a/horde/classes/kobold/genstats.py +++ b/horde/classes/kobold/genstats.py @@ -1,27 +1,11 @@ -from datetime import datetime, timedelta +from datetime import datetime -from sqlalchemy import Enum, func +from sqlalchemy import Enum from horde.enums import ImageGenState from horde.flask import db -class TextGenerationStatistic(db.Model): - __tablename__ = "text_gen_stats" - id = db.Column(db.Integer, 
primary_key=True) - finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow) - # Created comes from the procgen - created = db.Column(db.DateTime(timezone=False), nullable=True) - model = db.Column(db.String(255), index=True, nullable=False) - max_length = db.Column(db.Integer, nullable=False) - max_context_length = db.Column(db.Integer, nullable=False) - softprompt = db.Column(db.Integer, nullable=True) - prompt_length = db.Column(db.Integer, nullable=False) - client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True) - bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True) - state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True) - - def record_text_statistic(procgen): state = ImageGenState.OK # Currently there's no way to record cancelled images, but maybe there will be in the future @@ -44,71 +28,83 @@ def record_text_statistic(procgen): db.session.commit() -def compile_textgen_stats_totals(): - count_query = db.session.query(TextGenerationStatistic) - count_minute = count_query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1), - ).count() - count_hour = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1)).count() - count_day = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1)).count() - count_month = count_query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), - ).count() - count_total = count_query.count() - tokens_query = db.session.query(func.sum(TextGenerationStatistic.max_length)) - tokens_minute = tokens_query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1), - ).scalar() - tokens_hour = tokens_query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1), - ).scalar() - tokens_day = tokens_query.filter( - 
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1), - ).scalar() - tokens_month = tokens_query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), - ).scalar() - tokens_total = tokens_query.scalar() - stats_dict = { - "minute": { - "requests": count_minute, - "tokens": tokens_minute, - }, - "hour": { - "requests": count_hour, - "tokens": tokens_hour, - }, - "day": { - "requests": count_day, - "tokens": tokens_day, - }, - "month": { - "requests": count_month, - "tokens": tokens_month, - }, - "total": { - "requests": count_total, - "tokens": tokens_total, - }, - } +class TextGenerationStatistic(db.Model): + __tablename__ = "text_gen_stats" + id = db.Column(db.Integer, primary_key=True) + finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True) + # Created comes from the procgen + created = db.Column(db.DateTime(timezone=False), nullable=True) + model = db.Column(db.String(255), nullable=False, index=True) + max_length = db.Column(db.Integer, nullable=False) + max_context_length = db.Column(db.Integer, nullable=False) + softprompt = db.Column(db.Integer, nullable=True) + prompt_length = db.Column(db.Integer, nullable=False) + client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True) + bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True) + state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True) + + +class CompiledTextGensStatsTotals(db.Model): + __tablename__ = "compiled_text_gen_stats_totals" + id = db.Column(db.Integer, primary_key=True) + created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True) + minute_requests = db.Column(db.Integer, nullable=False) + minute_tokens = db.Column(db.Integer, nullable=False) + hour_requests = db.Column(db.Integer, nullable=False) + hour_tokens = db.Column(db.Integer, nullable=False) + day_requests = 
db.Column(db.Integer, nullable=False) + day_tokens = db.Column(db.Integer, nullable=False) + month_requests = db.Column(db.Integer, nullable=False) + month_tokens = db.Column(db.Integer, nullable=False) + total_requests = db.Column(db.Integer, nullable=False) + total_tokens = db.Column(db.BigInteger, nullable=False) + + +def get_compiled_textgen_stats_totals() -> dict[str, dict[str, int]]: + """Get the compiled text generation statistics for the minute, hour, day, month, and total periods. + + Returns: + dict[str, dict[str, int]]: A dictionary with the period as the key and the requests and tokens as the values. + """ + query = db.session.query(CompiledTextGensStatsTotals).order_by(CompiledTextGensStatsTotals.created.desc()).first() + + periods = ["minute", "hour", "day", "month", "total"] + stats_dict = {period: {"requests": 0, "tokens": 0} for period in periods} + + if query: + for period in periods: + stats_dict[period]["requests"] = getattr(query, f"{period}_requests") + stats_dict[period]["tokens"] = getattr(query, f"{period}_tokens") + return stats_dict -def compile_textgen_stats_models(): - query = db.session.query(TextGenerationStatistic.model, func.count()).group_by(TextGenerationStatistic.model) - ret_dict = { - "total": {model: count for model, count in query.all()}, - "day": { - model: count - for model, count in query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1), - ).all() - }, - "month": { - model: count - for model, count in query.filter( - TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), - ).all() - }, - } - return ret_dict +class CompiledTextGenStatsModels(db.Model): + __tablename__ = "compiled_text_gen_stats_models" + id = db.Column(db.Integer, primary_key=True) + created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True) + model = db.Column(db.String(255), nullable=False, index=True) + day_requests = db.Column(db.Integer, nullable=False) + month_requests 
= db.Column(db.Integer, nullable=False) + total_requests = db.Column(db.Integer, nullable=False) + + +def get_compiled_textgen_stats_models() -> dict[str, dict[str, int]]: + """Get the compiled text generation statistics for the day, month, and total periods for each model. + + Returns: + dict[str, dict[str, int]]: A dictionary with the model as the key and the requests as the values. + """ + + models: tuple[CompiledTextGenStatsModels] = ( + db.session.query(CompiledTextGenStatsModels).order_by(CompiledTextGenStatsModels.created.desc()).all() + ) + + periods = ["day", "month", "total"] + stats = {period: {model.model: 0 for model in models} for period in periods} + + for model in models: + for period in periods: + stats[period][model.model] = getattr(model, f"{period}_requests") + + return stats diff --git a/horde/classes/stable/genstats.py b/horde/classes/stable/genstats.py index a1433e34..8c217333 100644 --- a/horde/classes/stable/genstats.py +++ b/horde/classes/stable/genstats.py @@ -1,10 +1,9 @@ -from datetime import datetime, timedelta +from datetime import datetime -from sqlalchemy import Enum, func +from sqlalchemy import Enum from horde.enums import ImageGenState from horde.flask import db -from horde.model_reference import model_reference class ImageGenerationStatisticPP(db.Model): @@ -58,7 +57,7 @@ class ImageGenerationStatisticTI(db.Model): class ImageGenerationStatistic(db.Model): __tablename__ = "image_gen_stats" id = db.Column(db.Integer, primary_key=True) - finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow) + finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True) # Created comes from the procgen created = db.Column(db.DateTime(timezone=False), nullable=True) model = db.Column(db.String(255), index=True, nullable=False) @@ -161,78 +160,97 @@ def record_image_statistic(procgen): db.session.commit() -def compile_imagegen_stats_totals(): - count_query = db.session.query(ImageGenerationStatistic) - 
count_minute = count_query.filter( - ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1), - ).count() - count_hour = count_query.filter( - ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1), - ).count() - count_day = count_query.filter(ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1)).count() - count_month = count_query.filter( - ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), - ).count() - count_total = count_query.count() - ps_query = db.session.query( - func.sum(ImageGenerationStatistic.width * ImageGenerationStatistic.height * ImageGenerationStatistic.steps), - ) - ps_minute = ps_query.filter(ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1)).scalar() - ps_hour = ps_query.filter(ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1)).scalar() - ps_day = ps_query.filter(ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1)).scalar() - ps_month = ps_query.filter(ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30)).scalar() - ps_total = ps_query.scalar() - stats_dict = { - "minute": { - "images": count_minute, - "ps": ps_minute, - }, - "hour": { - "images": count_hour, - "ps": ps_hour, - }, - "day": { - "images": count_day, - "ps": ps_day, - }, - "month": { - "images": count_month, - "ps": ps_month, - }, - "total": { - "images": count_total, - "ps": ps_total, - }, - } - return stats_dict - - -def compile_imagegen_stats_models(model_state="known"): - query = db.session.query(ImageGenerationStatistic.model, func.count()).group_by(ImageGenerationStatistic.model) - - def check_model_state(model_name): - if model_state == "known" and model_reference.is_known_image_model(model_name): - return True - if model_state == "custom" and not model_reference.is_known_image_model(model_name): - return True - if model_state == "all": - return True - return False - - return { - 
"total": {model: count for model, count in query.all() if check_model_state(model)}, - "day": { - model: count - for model, count in query.filter( - ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1), - ).all() - if check_model_state(model) - }, - "month": { - model: count - for model, count in query.filter( - ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), - ).all() - if check_model_state(model) - }, - } +class CompiledImageGenStatsTotals(db.Model): + """A table to store the compiled image generation statistics for the minute, hour, day, month, and total periods.""" + + __tablename__ = "compiled_image_gen_stats_totals" + id = db.Column(db.Integer, primary_key=True) + created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True) + minute_images = db.Column(db.Integer, nullable=False) + minute_pixels = db.Column(db.Integer, nullable=False) + hour_images = db.Column(db.Integer, nullable=False) + hour_pixels = db.Column(db.Integer, nullable=False) + day_images = db.Column(db.Integer, nullable=False) + day_pixels = db.Column(db.Integer, nullable=False) + month_images = db.Column(db.Integer, nullable=False) + month_pixels = db.Column(db.Integer, nullable=False) + total_images = db.Column(db.Integer, nullable=False) + total_pixels = db.Column(db.BigInteger, nullable=False) + + +def get_compiled_imagegen_stats_totals() -> dict[str, dict[str, int]]: + """Get the precompiled image generation statistics the minute, hour, day, month, and total periods. + + Returns: + dict[str, dict[str, int]]: A dictionary containing the number of images and pixels generated for each period. 
+ """ + + latest_entry = db.session.query(CompiledImageGenStatsTotals).order_by(CompiledImageGenStatsTotals.created.desc()).first() + + periods = ["minute", "hour", "day", "month", "total"] + stats = {period: {"images": 0, "ps": 0} for period in periods} + + if latest_entry: + for period in periods: + stats[period]["images"] = getattr(latest_entry, f"{period}_images") + stats[period]["ps"] = getattr(latest_entry, f"{period}_pixels") + + return stats + + +class CompiledImageGenStatsModels(db.Model): + """A table to store the compiled image generation statistics for each model.""" + + __tablename__ = "compiled_image_gen_stats_models" + id = db.Column(db.Integer, primary_key=True) + created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True, nullable=False) + model_id = db.Column(db.Integer, db.ForeignKey("known_image_models.id"), nullable=True) + model = db.relationship("KnownImageModel", backref=db.backref("known_image_models", lazy=True)) + model_name = db.Column(db.String(255), nullable=False) + model_state = db.Column(db.String(16), nullable=False) + day_images = db.Column(db.Integer, nullable=False) + month_images = db.Column(db.Integer, nullable=False) + total_images = db.Column(db.Integer, nullable=False) + + +def get_compiled_imagegen_stats_models(model_state: str = "all") -> dict[str, dict[str, dict[str, int]]]: + """Gets the precompiled image generation statistics for the day, month, and total periods for each model.""" + + models: tuple[CompiledImageGenStatsModels] = () + + # If model_state is "all" we get all models, if it's "known" we get only known models, if it's "custom" we get only custom models + if model_state == "all": + models = db.session.query(CompiledImageGenStatsModels.model_name).distinct().all() + elif model_state == "known": + models = ( + db.session.query(CompiledImageGenStatsModels.model_name) + .filter(CompiledImageGenStatsModels.model_state == "known") + .distinct() + .all() + ) + elif model_state == "custom": 
+ models = ( + db.session.query(CompiledImageGenStatsModels.model_name) + .filter(CompiledImageGenStatsModels.model_state == "custom") + .distinct() + .all() + ) + else: + raise ValueError("Invalid model_state. Expected 'all', 'known', or 'custom'.") + + periods = ["day", "month", "total"] + stats = {period: {model.model_name: 0 for model in models} for period in periods} + + for model in models: + latest_entry = ( + db.session.query(CompiledImageGenStatsModels) + .filter_by(model_name=model.model_name) + .order_by(CompiledImageGenStatsModels.created.desc()) + .first() + ) + + if latest_entry: + for period in periods: + stats[period][model.model_name] = getattr(latest_entry, f"{period}_images") + + return stats diff --git a/horde/classes/stable/known_image_models.py b/horde/classes/stable/known_image_models.py new file mode 100644 index 00000000..08ac6d2f --- /dev/null +++ b/horde/classes/stable/known_image_models.py @@ -0,0 +1,219 @@ +from datetime import datetime +from typing import Union + +from horde.flask import db +from horde.logger import logger + + +class KnownImageModel(db.Model): + """The schema for the known image models database table.""" + + __tablename__ = "known_image_models" + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(128), nullable=False) + baseline = db.Column(db.String(128), nullable=False) + """The baseline of the model. 
For example, 'stable diffusion 1' or 'stable_diffusion_xl`.""" + inpainting = db.Column(db.Boolean, nullable=False) + description = db.Column(db.String(512), nullable=True) + version = db.Column(db.String(16), nullable=False) + style = db.Column(db.String(64), nullable=False) + tags = db.Column(db.JSON, nullable=False) + homepage = db.Column(db.String(256), nullable=True) + nsfw = db.Column(db.Boolean, nullable=False) + requirements = db.Column(db.JSON, nullable=True) + config = db.Column(db.JSON, nullable=False) + features_not_supported = db.Column(db.JSON, nullable=True) + size_on_disk_bytes = db.Column(db.BigInteger, nullable=True) + """The size of the model on disk in bytes.""" + created_at = db.Column(db.DateTime, default=datetime.utcnow) + """The time the model was added to the database.""" + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + """The time the model was last updated in the database.""" + + +@logger.catch(reraise=True) +def get_known_image_models() -> list[KnownImageModel]: + """Get all known image models from the database.""" + return db.session.query(KnownImageModel).all() + + +@logger.catch(reraise=True) +def is_model_known(model_name: Union[KnownImageModel, str]) -> bool: + """Check if a model is known in the database. + + Args: + model_name (str): The name of the model to check. + + Returns: + bool: Whether the model is known. 
+ """ + if isinstance(model_name, KnownImageModel): + model_name = model_name.name + + return db.session.query(KnownImageModel).filter(KnownImageModel.name == model_name).first() is not None + + +@logger.catch(reraise=True) +def add_known_image_model( + name: str, + baseline: str, + inpainting: bool, + description: str, + version: str, + style: str, + tags: list[str], + homepage: str, + nsfw: bool, + requirements: dict, + config: dict, + features_not_supported: list[str], + size_on_disk_bytes: int, + *, + defer_commit: bool = False, +) -> None: + """Add an image model to the database. This function will update the model if it already exists. + + Note that the arguments of this function reflect those found in the model reference JSON. + + Args: + name (str): The name of the model. + baseline (str): The baseline model used. + inpainting (bool): Whether the model is capable of inpainting. + description (str): A description of the model. + version (str): The version of the model. + style (str): The style of the model. + tags (list[str]): A list of tags for the model. + homepage (str): The homepage of the model. + nsfw (bool): Whether the model is NSFW. + requirements (dict): The requirements of the model. + config (dict): The configuration of the model. + features_not_supported (list[str]): A list of features not supported by the model. + size_on_disk_bytes (int): The size of the model on disk. + + defer_commit (bool): Whether to defer committing the addition to the database. 
+ """ + + model: Union[KnownImageModel, None] = db.session.query(KnownImageModel).filter(KnownImageModel.name == name).first() + + if model: + model.baseline = baseline + model.inpainting = inpainting + model.description = description + model.version = version + model.style = style + model.tags = tags + model.homepage = homepage + model.nsfw = nsfw + model.requirements = requirements + model.config = config + model.features_not_supported = features_not_supported + model.size_on_disk_bytes = size_on_disk_bytes + else: + logger.info(f"Attempting to add new known image model: {name}") + model = KnownImageModel( + name=name, + baseline=baseline, + inpainting=inpainting, + description=description, + version=version, + style=style, + tags=tags, + homepage=homepage, + nsfw=nsfw, + requirements=requirements, + config=config, + features_not_supported=features_not_supported, + size_on_disk_bytes=size_on_disk_bytes, + ) + db.session.add(model) + + if not defer_commit: + db.session.commit() + + +@logger.catch(reraise=True) +def add_known_image_model_from_json(json: dict[str, object], defer_commit: bool = False) -> None: + """Add a image model to the database from a JSON object. + + Args: + json (dict[str, object]): The model reference JSON object. + defer_commit (bool): Whether to defer committing the addition to the database. 
+ + """ + add_known_image_model( + name=json.get("name"), + baseline=json.get("baseline"), + inpainting=json.get("inpainting"), + description=json.get("description"), + version=json.get("version"), + style=json.get("style"), + tags=json.get("tags"), + homepage=json.get("homepage"), + nsfw=json.get("nsfw"), + requirements=json.get("requirements"), + config=json.get("config"), + features_not_supported=json.get("features_not_supported"), + size_on_disk_bytes=json.get("size_on_disk_bytes"), + defer_commit=defer_commit, + ) + + +@logger.catch(reraise=True) +def add_known_image_models_from_json(json: dict[str, dict]) -> None: + """Add multiple image models to the database from a JSON object. + + Args: + json (dict[str, dict]): The model reference JSON object. + """ + for model in json.values(): + add_known_image_model_from_json(model, defer_commit=True) + + db.session.commit() + logger.info(f"Added (or updated) {len(json)} known image models.") + + +@logger.catch(reraise=True) +def delete_known_image_model(model_name: str, defer_commit: bool = False) -> bool: + """Attempt to delete a known image model from the database. + + Args: + model_name (str): Name of the model to delete. + defer_commit (bool): Whether to defer committing the deletion to the database. + + Returns: + bool: Whether the model was deleted, or if defer_commit is True, whether the model was found and queued for deletion. + """ + model = db.session.query(KnownImageModel).filter(KnownImageModel.name == model_name).first() + if model: + db.session.delete(model) + logger.info(f"Queueing deletion of known image model: {model_name}") + if not defer_commit: + db.session.commit() + + return True + else: + logger.error(f"Model {model_name} not found in the database") + + return False + + +@logger.catch(reraise=True) +def delete_any_unspecified_image_models(models_desired: list[str]) -> None: + """Delete any models not specified in the list from the database. 
+ + Args: + models_desired (list[str]): List of model names to keep in the database. + """ + models_records_in_db = db.session.query(KnownImageModel).all() + model_names_in_db = [model.name for model in models_records_in_db] + num_deleted = 0 + for model in model_names_in_db: + if model not in models_desired: + was_deleted = delete_known_image_model(model, defer_commit=True) + if was_deleted: + num_deleted += 1 + + if num_deleted > 0: + logger.info(f"Deleted {num_deleted} models from the database") + + db.session.commit() diff --git a/horde/database/__init__.py b/horde/database/__init__.py index 7b9aa7cf..c7e61f51 100644 --- a/horde/database/__init__.py +++ b/horde/database/__init__.py @@ -19,6 +19,7 @@ priority_increaser = PrimaryTimedFunction(10, threads.increment_extra_priority, quorum=quorum) compiled_filter_cacher = PrimaryTimedFunction(10, threads.store_compiled_filter_regex, quorum=quorum) regex_replacements_cacher = PrimaryTimedFunction(10, threads.store_compiled_filter_regex_replacements, quorum=quorum) +known_image_models_cacher = PrimaryTimedFunction(300, threads.store_known_image_models, quorum=quorum) if args.reload_all_caches: logger.info("store_prioritized_wp_queue()") @@ -35,6 +36,8 @@ threads.store_compiled_filter_regex_replacements() logger.info("store_available_models()") threads.store_available_models() + logger.info("store_known_image_models()") + threads.store_known_image_models() if args.check_prompts: diff --git a/horde/database/threads.py b/horde/database/threads.py index 8b7bfa88..3689a562 100644 --- a/horde/database/threads.py +++ b/horde/database/threads.py @@ -396,3 +396,22 @@ def store_compiled_filter_regex_replacements(): replacements = retrieve_regex_replacements(10) # We don't expire filters once set, to avoid ever losing the cache and letting prompts through hr.horde_r_set("cached_regex_replacements", json.dumps(replacements)) + + +@logger.catch(reraise=True) +def store_known_image_models(): + """Stores the known image models in 
the database""" + from horde.classes.stable.known_image_models import ( + add_known_image_models_from_json, + delete_any_unspecified_image_models, + ) + from horde.model_reference import model_reference + + with HORDE.app_context(): + if model_reference.reference is not None: + logger.debug("Storing known image models from the model reference") + add_known_image_models_from_json(model_reference.reference) + delete_any_unspecified_image_models(list(model_reference.reference.keys())) + + else: + logger.debug("No known image models to store from the model reference") diff --git a/horde/model_reference.py b/horde/model_reference.py index a6459469..82e65a41 100644 --- a/horde/model_reference.py +++ b/horde/model_reference.py @@ -55,9 +55,11 @@ def call_function(self): self.nsfw_models.add(model) if self.reference[model].get("type") == "controlnet": self.controlnet_models.add(model) + break except Exception as e: logger.error(f"Error when downloading nataili models list: {e}") + for _riter in range(10): try: self.text_reference = requests.get( diff --git a/sql_statements/README.md b/sql_statements/README.md new file mode 100644 index 00000000..e1e651d2 --- /dev/null +++ b/sql_statements/README.md @@ -0,0 +1,19 @@ +## AI-Horde Database Information + +- postgresql >=15 +- [pg_cron](https://github.com/citusdata/pg_cron) + + +## `pg_cron` config + +> **Warning**: All `.sql` files found in a directory deeper than `sql_statements/` will be dynamically run, not only the ones specifically identified in this document. Only place `.sql` files you intend to run in these directories. This does not apply to the `sql_statements` level (i.e., `sql_statements/4.35.1.sql` is not automatically run, but `sql_statements/cron/your_new_file.sql` will be.) 
+ +- `cron/` + - `schedule_cron_job.sql` + - Creates a stored procedure which schedules a new pg_cron job to execute a specified stored procedure at intervals defined by a cron schedule string, **if a job with the same command doesn't already exist**. If such a job exists with a different schedule, it is unscheduled and re-created with the new schedule. + - e.g., `CALL schedule_cron_job('0-59 * * * *', 'compile_imagegen_stats_totals');` +- `stored_procedures` + - `compile_*gen_stats_*.sql` + - These files define stored procedures which populate the `compiled_*` tables, which generally represent minute/hour/day/month/total statistics about generations. + - `cron_jobs/` + - Schedules any stats compile jobs via `schedule_cron_job`. \ No newline at end of file diff --git a/sql_statements/cron/schedule_cron_job.sql b/sql_statements/cron/schedule_cron_job.sql new file mode 100644 index 00000000..e4ed4d01 --- /dev/null +++ b/sql_statements/cron/schedule_cron_job.sql @@ -0,0 +1,33 @@ +CREATE EXTENSION IF NOT EXISTS pg_cron; + +CREATE OR REPLACE PROCEDURE schedule_cron_job( + p_schedule TEXT, + p_stored_procedure TEXT +) +LANGUAGE plpgsql +AS $$ +DECLARE + existing_schedule TEXT; + existing_jobid INT; +BEGIN + SET search_path TO cron, public; + + -- Get the existing schedule and jobid for the stored procedure + SELECT schedule, jobid + INTO existing_schedule, existing_jobid + FROM cron.job + WHERE command = format($CRON$ CALL %s(); $CRON$, p_stored_procedure); + + -- If the job exists and the schedules don't match, update it + IF FOUND AND existing_schedule <> p_schedule THEN + PERFORM cron.unschedule(existing_jobid); + PERFORM cron.schedule(p_schedule, format($CRON$ CALL %s(); $CRON$, p_stored_procedure)); + RAISE NOTICE 'Cron job schedule updated successfully for stored procedure: %', p_stored_procedure; + -- If the job doesn't exist, schedule it + ELSIF NOT FOUND THEN + PERFORM cron.schedule(p_schedule, format($CRON$ CALL %s(); $CRON$, p_stored_procedure)); + RAISE NOTICE 'Cron job scheduled successfully for stored procedure: %', p_stored_procedure; + ELSE + RAISE NOTICE 'Cron 
job already exists with the same schedule for stored procedure: %. Skipping scheduling.', p_stored_procedure; + END IF; +END $$; diff --git a/sql_statements/stored_procedures/compile_imagegen_stats_models.sql b/sql_statements/stored_procedures/compile_imagegen_stats_models.sql new file mode 100644 index 00000000..0546626f --- /dev/null +++ b/sql_statements/stored_procedures/compile_imagegen_stats_models.sql @@ -0,0 +1,33 @@ +CREATE OR REPLACE PROCEDURE compile_imagegen_stats_models() +LANGUAGE plpgsql +AS $$ +BEGIN + WITH model_stats AS ( + SELECT + kim.id as model_id, + igs.model as model_name, + CASE + WHEN kim.id IS NOT NULL THEN 'known' + ELSE 'custom' + END as model_state, + COUNT(*) FILTER (WHERE igs.finished >= NOW() - INTERVAL '1 day') as day_images, + COUNT(*) FILTER (WHERE igs.finished >= NOW() - INTERVAL '30 days') as month_images, + COUNT(*) as total_images + FROM + image_gen_stats as igs + LEFT JOIN known_image_models as kim ON igs.model = kim.name + GROUP BY + igs.model, kim.id + ) + INSERT INTO compiled_image_gen_stats_models (created, model_id, model_name, model_state, day_images, month_images, total_images) + SELECT + NOW(), + model_id, + model_name, + model_state, + day_images, + month_images, + total_images + FROM + model_stats; +END; $$; diff --git a/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql b/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql new file mode 100644 index 00000000..35707e3c --- /dev/null +++ b/sql_statements/stored_procedures/compile_imagegen_stats_totals.sql @@ -0,0 +1,39 @@ +CREATE OR REPLACE PROCEDURE compile_imagegen_stats_totals() +LANGUAGE plpgsql +AS $$ +DECLARE + count_minute INTEGER; + count_hour INTEGER; + count_day INTEGER; + count_month INTEGER; + count_total INTEGER; + ps_minute INTEGER; + ps_hour INTEGER; + ps_day INTEGER; + ps_month INTEGER; + ps_total BIGINT; +BEGIN + -- Calculate image counts + SELECT COUNT(*) INTO count_minute FROM image_gen_stats WHERE finished >= NOW() - 
INTERVAL '1 minute'; + SELECT COUNT(*) INTO count_hour FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '1 hour'; + SELECT COUNT(*) INTO count_day FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '1 day'; + SELECT COUNT(*) INTO count_month FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '30 days'; + SELECT COUNT(*) INTO count_total FROM image_gen_stats; + + -- Calculate pixel sums + SELECT COALESCE(SUM(width * height * steps), 0) INTO ps_minute FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '1 minute'; + SELECT COALESCE(SUM(width * height * steps), 0) INTO ps_hour FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '1 hour'; + SELECT COALESCE(SUM(width * height * steps), 0) INTO ps_day FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '1 day'; + SELECT COALESCE(SUM(width * height * steps), 0) INTO ps_month FROM image_gen_stats WHERE finished >= NOW() - INTERVAL '30 days'; + SELECT COALESCE(SUM(width * height * steps), 0) INTO ps_total FROM image_gen_stats; + + -- Insert compiled statistics into compiled_image_gen_stats_totals + INSERT INTO compiled_image_gen_stats_totals ( + created, minute_images, minute_pixels, hour_images, hour_pixels, + day_images, day_pixels, month_images, month_pixels, total_images, total_pixels + ) VALUES ( + NOW(), count_minute, ps_minute, count_hour, ps_hour, + count_day, ps_day, count_month, ps_month, count_total, ps_total + ); +END; +$$; diff --git a/sql_statements/stored_procedures/compile_textgen_stats_models.sql b/sql_statements/stored_procedures/compile_textgen_stats_models.sql new file mode 100644 index 00000000..c51c9c5f --- /dev/null +++ b/sql_statements/stored_procedures/compile_textgen_stats_models.sql @@ -0,0 +1,26 @@ +CREATE OR REPLACE PROCEDURE compile_textgen_stats_models() +LANGUAGE plpgsql +AS $$ +BEGIN + WITH model_stats AS ( + SELECT + tgs.model as model_name, + COUNT(*) FILTER (WHERE tgs.finished >= NOW() - INTERVAL '1 day') as day_requests, + COUNT(*) FILTER (WHERE tgs.finished 
>= NOW() - INTERVAL '30 days') as month_requests, + COUNT(*) as total_requests + FROM + text_gen_stats as tgs + GROUP BY + tgs.model + ) + INSERT INTO compiled_text_gen_stats_models (created, model, day_requests, month_requests, total_requests) + SELECT + NOW(), + model_name, + day_requests, + month_requests, + total_requests + FROM + model_stats; + COMMIT; +END; $$; diff --git a/sql_statements/stored_procedures/compile_textgen_stats_totals.sql b/sql_statements/stored_procedures/compile_textgen_stats_totals.sql new file mode 100644 index 00000000..501189bb --- /dev/null +++ b/sql_statements/stored_procedures/compile_textgen_stats_totals.sql @@ -0,0 +1,39 @@ +CREATE OR REPLACE PROCEDURE compile_textgen_stats_totals() +LANGUAGE plpgsql +AS $$ +DECLARE + count_minute INTEGER; + count_hour INTEGER; + count_day INTEGER; + count_month INTEGER; + count_total INTEGER; + tokens_minute INTEGER; + tokens_hour INTEGER; + tokens_day INTEGER; + tokens_month INTEGER; + tokens_total BIGINT; +BEGIN + -- Calculate request counts + SELECT COUNT(*) INTO count_minute FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 minute'; + SELECT COUNT(*) INTO count_hour FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 hour'; + SELECT COUNT(*) INTO count_day FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 day'; + SELECT COUNT(*) INTO count_month FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '30 days'; + SELECT COUNT(*) INTO count_total FROM text_gen_stats; + + -- Calculate token sums + SELECT COALESCE(SUM(max_length), 0) INTO tokens_minute FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 minute'; + SELECT COALESCE(SUM(max_length), 0) INTO tokens_hour FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 hour'; + SELECT COALESCE(SUM(max_length), 0) INTO tokens_day FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '1 day'; + SELECT COALESCE(SUM(max_length), 0) INTO tokens_month FROM text_gen_stats WHERE finished >= NOW() - INTERVAL '30 days'; + 
SELECT COALESCE(SUM(max_length), 0) INTO tokens_total FROM text_gen_stats; + + -- Insert compiled statistics into compiled_text_gen_stats_totals + INSERT INTO compiled_text_gen_stats_totals ( + created, minute_requests, minute_tokens, hour_requests, hour_tokens, + day_requests, day_tokens, month_requests, month_tokens, total_requests, total_tokens + ) VALUES ( + NOW(), count_minute, tokens_minute, count_hour, tokens_hour, + count_day, tokens_day, count_month, tokens_month, count_total, tokens_total + ); +END; +$$; diff --git a/sql_statements/stored_procedures/cron_jobs/cron_stats.sql b/sql_statements/stored_procedures/cron_jobs/cron_stats.sql new file mode 100644 index 00000000..e0a37b95 --- /dev/null +++ b/sql_statements/stored_procedures/cron_jobs/cron_stats.sql @@ -0,0 +1,4 @@ +CALL schedule_cron_job('0 1 1-31 * *', 'compile_imagegen_stats_models'); +CALL schedule_cron_job('0-59 * * * *', 'compile_imagegen_stats_totals'); +CALL schedule_cron_job('0 1 1-31 * *', 'compile_textgen_stats_models'); +CALL schedule_cron_job('0-59 * * * *', 'compile_textgen_stats_totals');