Skip to content

Commit

Permalink
feat: caching improvements for stats endpoints (#420)
Browse files Browse the repository at this point in the history
* fix: index text/img stat `finished` field in db

* feat: pg_cron based stats tables

* fix: working `get_compiled_imagegen_stats_models`

* fix: correct return schema for image stats models

* fix: models stats cron job now is once per day

* fix: use 3.9 typehint syntax

* fix: correct text models stats schema

* fix: use correct table name for text gen totals

* fix: use correct table name for image gen totals

* fix: db `schedule_cron_job` now updates too

If the cron string (`p_schedule`) is different than the currently scheduled cron job, `schedule_cron_job` now makes sure that the schedule is updated.

* style: fix
  • Loading branch information
tazlin authored Jun 10, 2024
1 parent bc264b8 commit 4072a48
Show file tree
Hide file tree
Showing 17 changed files with 650 additions and 172 deletions.
8 changes: 4 additions & 4 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from horde.apis.v2.base import GenerateTemplate, JobPopTemplate, JobSubmitTemplate, api
from horde.classes.base import settings
from horde.classes.kobold.genstats import (
compile_textgen_stats_models,
compile_textgen_stats_totals,
get_compiled_textgen_stats_models,
get_compiled_textgen_stats_totals,
)
from horde.classes.kobold.waiting_prompt import TextWaitingPrompt
from horde.classes.kobold.worker import TextWorker
Expand Down Expand Up @@ -356,7 +356,7 @@ def get(self):
"""Details how many texts have been generated in the past minux,hour,day,month and total
Also shows the amount of pixelsteps for the same timeframe.
"""
return compile_textgen_stats_totals(), 200
return get_compiled_textgen_stats_totals(), 200


class TextHordeStatsModels(Resource):
Expand All @@ -380,7 +380,7 @@ class TextHordeStatsModels(Resource):
)
def get(self):
"""Details how many texts were generated per model for the past day, month and total"""
return compile_textgen_stats_models(), 200
return get_compiled_textgen_stats_models(), 200


class KoboldKudosTransfer(Resource):
Expand Down
9 changes: 5 additions & 4 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from horde.classes.base import settings
from horde.classes.base.user import User
from horde.classes.stable.genstats import (
compile_imagegen_stats_models,
compile_imagegen_stats_totals,
get_compiled_imagegen_stats_models,
get_compiled_imagegen_stats_totals,
)
from horde.classes.stable.interrogation import Interrogation
from horde.classes.stable.interrogation_worker import InterrogationWorker
Expand Down Expand Up @@ -573,6 +573,7 @@ def post(self):
if "blacklist" in post_ret.get("skipped", {}):
db_skipped["blacklist"] = post_ret["skipped"]["blacklist"]
post_ret["skipped"] = db_skipped

return post_ret, retcode

def check_in(self):
Expand Down Expand Up @@ -1272,7 +1273,7 @@ def get(self):
"""Details how many images have been generated in the past minux,hour,day,month and total
Also shows the amount of pixelsteps for the same timeframe.
"""
return compile_imagegen_stats_totals(), 200
return get_compiled_imagegen_stats_totals(), 200


class ImageHordeStatsModels(Resource):
Expand Down Expand Up @@ -1312,4 +1313,4 @@ def get(self):
self.args = self.get_parser.parse_args()
if self.args.model_state not in ["known", "custom", "all"]:
raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']")
return compile_imagegen_stats_models(self.args.model_state), 200
return get_compiled_imagegen_stats_models(self.args.model_state), 200
29 changes: 28 additions & 1 deletion horde/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import horde.classes.base.stats # noqa 401
from horde.argparser import args
from horde.classes.base.detection import Filter # noqa 401
Expand All @@ -10,13 +12,14 @@
from horde.classes.kobold.worker import TextWorker # noqa 401
from horde.classes.stable.interrogation import Interrogation # noqa 401
from horde.classes.stable.interrogation_worker import InterrogationWorker # noqa 401
from horde.classes.stable.known_image_models import KnownImageModel # noqa 401

# Importing for DB creation

# noqa 401
from horde.classes.stable.waiting_prompt import ImageWaitingPrompt # noqa 401
from horde.classes.stable.worker import ImageWorker # noqa 401
from horde.flask import HORDE, db
from horde.logger import logger
from horde.utils import hash_api_key

with HORDE.app_context():
Expand All @@ -29,6 +32,29 @@
# sys.exit()
db.create_all()

sql_statement_dir = Path(__file__).parent.parent.parent / "sql_statements"

# The order of these directories is important. `cron` creates a stored procedure that is
# used by queries in all other `cron_jobs/` directories.
all_dirs_to_run = [
"cron/", # Must be first
"stored_procedures/",
"stored_procedures/cron_jobs/",
]

all_dirs_to_run = [sql_statement_dir / dir for dir in all_dirs_to_run]

with logger.catch(reraise=True):
for dir in all_dirs_to_run:
logger.info(f"Running files in {dir}")
for file in dir.iterdir():
if file.suffix == ".sql":
logger.info(f"Running {file}")
with file.open() as f:
db.session.execute(f.read())

db.session.commit()

if args.convert_flag == "roles":
# from horde.conversions import convert_user_roles

Expand Down Expand Up @@ -65,6 +91,7 @@
"TextWaitingPrompt",
"Interrogation",
"InterrogationWorker",
"KnownImageModel",
"User",
"Team",
"ImageWorker",
Expand Down
2 changes: 1 addition & 1 deletion horde/classes/base/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class User(db.Model):
oauth_id = db.Column(db.String(50), unique=True, nullable=False, index=True)
api_key = db.Column(db.String(100), unique=True, nullable=False, index=True)
client_id = db.Column(db.String(50), unique=True, default=generate_client_id, nullable=False)
created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
last_active = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
contact = db.Column(db.String(50), default=None)
admin_comment = db.Column(db.Text, default=None)
Expand Down
162 changes: 79 additions & 83 deletions horde/classes/kobold/genstats.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
from datetime import datetime, timedelta
from datetime import datetime

from sqlalchemy import Enum, func
from sqlalchemy import Enum

from horde.enums import ImageGenState
from horde.flask import db


class TextGenerationStatistic(db.Model):
__tablename__ = "text_gen_stats"
id = db.Column(db.Integer, primary_key=True)
finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow)
# Created comes from the procgen
created = db.Column(db.DateTime(timezone=False), nullable=True)
model = db.Column(db.String(255), index=True, nullable=False)
max_length = db.Column(db.Integer, nullable=False)
max_context_length = db.Column(db.Integer, nullable=False)
softprompt = db.Column(db.Integer, nullable=True)
prompt_length = db.Column(db.Integer, nullable=False)
client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True)


def record_text_statistic(procgen):
state = ImageGenState.OK
# Currently there's no way to record cancelled images, but maybe there will be in the future
Expand All @@ -44,71 +28,83 @@ def record_text_statistic(procgen):
db.session.commit()


def compile_textgen_stats_totals():
count_query = db.session.query(TextGenerationStatistic)
count_minute = count_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1),
).count()
count_hour = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1)).count()
count_day = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1)).count()
count_month = count_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).count()
count_total = count_query.count()
tokens_query = db.session.query(func.sum(TextGenerationStatistic.max_length))
tokens_minute = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1),
).scalar()
tokens_hour = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1),
).scalar()
tokens_day = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1),
).scalar()
tokens_month = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).scalar()
tokens_total = tokens_query.scalar()
stats_dict = {
"minute": {
"requests": count_minute,
"tokens": tokens_minute,
},
"hour": {
"requests": count_hour,
"tokens": tokens_hour,
},
"day": {
"requests": count_day,
"tokens": tokens_day,
},
"month": {
"requests": count_month,
"tokens": tokens_month,
},
"total": {
"requests": count_total,
"tokens": tokens_total,
},
}
class TextGenerationStatistic(db.Model):
__tablename__ = "text_gen_stats"
id = db.Column(db.Integer, primary_key=True)
finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
# Created comes from the procgen
created = db.Column(db.DateTime(timezone=False), nullable=True)
model = db.Column(db.String(255), nullable=False, index=True)
max_length = db.Column(db.Integer, nullable=False)
max_context_length = db.Column(db.Integer, nullable=False)
softprompt = db.Column(db.Integer, nullable=True)
prompt_length = db.Column(db.Integer, nullable=False)
client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True)


class CompiledTextGensStatsTotals(db.Model):
__tablename__ = "compiled_text_gen_stats_totals"
id = db.Column(db.Integer, primary_key=True)
created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
minute_requests = db.Column(db.Integer, nullable=False)
minute_tokens = db.Column(db.Integer, nullable=False)
hour_requests = db.Column(db.Integer, nullable=False)
hour_tokens = db.Column(db.Integer, nullable=False)
day_requests = db.Column(db.Integer, nullable=False)
day_tokens = db.Column(db.Integer, nullable=False)
month_requests = db.Column(db.Integer, nullable=False)
month_tokens = db.Column(db.Integer, nullable=False)
total_requests = db.Column(db.Integer, nullable=False)
total_tokens = db.Column(db.BigInteger, nullable=False)


def get_compiled_textgen_stats_totals() -> dict[str, dict[str, int]]:
"""Get the compiled text generation statistics for the minute, hour, day, month, and total periods.
Returns:
dict[str, dict[str, int]]: A dictionary with the period as the key and the requests and tokens as the values.
"""
query = db.session.query(CompiledTextGensStatsTotals).order_by(CompiledTextGensStatsTotals.created.desc()).first()

periods = ["minute", "hour", "day", "month", "total"]
stats_dict = {period: {"requests": 0, "tokens": 0} for period in periods}

if query:
for period in periods:
stats_dict[period]["requests"] = getattr(query, f"{period}_requests")
stats_dict[period]["tokens"] = getattr(query, f"{period}_tokens")

return stats_dict


def compile_textgen_stats_models():
query = db.session.query(TextGenerationStatistic.model, func.count()).group_by(TextGenerationStatistic.model)
ret_dict = {
"total": {model: count for model, count in query.all()},
"day": {
model: count
for model, count in query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1),
).all()
},
"month": {
model: count
for model, count in query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).all()
},
}
return ret_dict
class CompiledTextGenStatsModels(db.Model):
__tablename__ = "compiled_text_gen_stats_models"
id = db.Column(db.Integer, primary_key=True)
created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
model = db.Column(db.String(255), nullable=False, index=True)
day_requests = db.Column(db.Integer, nullable=False)
month_requests = db.Column(db.Integer, nullable=False)
total_requests = db.Column(db.Integer, nullable=False)


def get_compiled_textgen_stats_models() -> dict[str, dict[str, int]]:
"""Get the compiled text generation statistics for the day, month, and total periods for each model.
Returns:
dict[str, dict[str, int]]: A dictionary with the model as the key and the requests as the values.
"""

models: tuple[CompiledTextGenStatsModels] = (
db.session.query(CompiledTextGenStatsModels).order_by(CompiledTextGenStatsModels.created.desc()).all()
)

periods = ["day", "month", "total"]
stats = {period: {model.model: 0 for model in models} for period in periods}

for model in models:
for period in periods:
stats[period][model.model] = getattr(model, f"{period}_requests")

return stats
Loading

0 comments on commit 4072a48

Please sign in to comment.