Skip to content

Commit

Permalink
[ADD] summary request cost tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
maschlr committed Sep 7, 2024
1 parent 69c7b0b commit 64b8b69
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add openai request data to summary
Revision ID: 8186365a22ab
Revises: 0712891cc70a
Create Date: 2024-09-07 09:58:48.662217
"""
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "8186365a22ab"
down_revision = "0712891cc70a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add OpenAI request/usage tracking columns to the ``summary`` table.

    All columns are nullable so existing rows remain valid.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    new_columns = (
        ("openai_id", sa.String()),
        ("openai_model", sa.String()),
        ("completion_tokens", sa.Integer()),
        ("prompt_tokens", sa.Integer()),
    )
    for column_name, column_type in new_columns:
        op.add_column("summary", sa.Column(column_name, column_type, nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the OpenAI usage-tracking columns (exact reverse of ``upgrade``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Dropped in reverse order of creation.
    for column_name in ("prompt_tokens", "completion_tokens", "openai_model", "openai_id"):
        op.drop_column("summary", column_name)
    # ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Add total_seconds to Transcript model
Revision ID: aad719a7e141
Revises: 8186365a22ab
Create Date: 2024-09-07 10:14:33.233855
"""
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "aad719a7e141"
down_revision = "8186365a22ab"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add a nullable ``total_seconds`` column (audio duration) to ``transcript``."""
    # ### commands auto generated by Alembic - please adjust! ###
    total_seconds = sa.Column("total_seconds", sa.Integer(), nullable=True)
    op.add_column("transcript", total_seconds)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Remove ``transcript.total_seconds`` (reverse of ``upgrade``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("transcript", "total_seconds")
    # ### end Alembic commands ###
6 changes: 4 additions & 2 deletions summaree_bot/bot/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def get_summary_msg(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
session.add(transcript)

summary = _summarize(update, context, transcript)
total_cost = summary.total_cost
bot_msg = _get_summary_message(update, context, summary)
chat = session.get(TelegramChat, update.effective_chat.id)

Expand Down Expand Up @@ -116,7 +117,7 @@ async def get_summary_msg(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
]
bot_msg.reply_markup = InlineKeyboardMarkup([buttons])

return bot_msg
return bot_msg, total_cost


async def download_large_file(chat_id: int, message_id: int, filepath: Path):
Expand Down Expand Up @@ -346,7 +347,7 @@ async def transcribe_and_summarize(update: Update, context: ContextTypes.DEFAULT
tg.create_task(context.bot.send_chat_action(update.effective_chat.id, ChatAction.TYPING))

start_message = start_msg_task.result()
bot_response_msg = bot_response_msg_task.result()
bot_response_msg, total_cost = bot_response_msg_task.result()

try:
text = (
Expand All @@ -358,6 +359,7 @@ async def transcribe_and_summarize(update: Update, context: ContextTypes.DEFAULT
f"📝 Summary \#{n_summaries + 1} created by user "
f"{update.effective_user.mention_markdown_v2()} \(in private chat\)"
)
text += escape_markdown(f"\n💰 Cost: $ {total_cost:.6f}" if total_cost else "", version=2)
new_summary_msg = AdminChannelMessage(
text=text,
parse_mode=ParseMode.MARKDOWN_V2,
Expand Down
21 changes: 16 additions & 5 deletions summaree_bot/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import telegram
from openai import AsyncOpenAI, BadRequestError, OpenAI
from openai.types.chat import ParsedChatCompletion
from pydantic import BaseModel
from sqlalchemy import func, select
from telegram.ext import ContextTypes
Expand Down Expand Up @@ -147,6 +148,7 @@ async def transcribe_file(
file_size=voice_or_audio.file_size,
result=whisper_transcription.text,
input_language=transcript_language,
total_seconds=whisper_transcription.total_seconds,
)
session.add(transcript)
session.flush()
Expand All @@ -157,6 +159,7 @@ async def transcribe_file(
class WhisperTranscription:
    # Aggregated result of transcribing one audio file, possibly in chunks
    # (see get_whisper_transcription, which joins per-chunk results).
    text: str  # transcription text; chunk texts joined with newlines
    language: str  # most common detected language across chunks (may be None if undeterminable)
    total_seconds: int  # summed audio duration reported by the API; used for cost tracking


async def get_whisper_transcription(file_path: Path):
Expand All @@ -175,8 +178,10 @@ async def get_whisper_transcription(file_path: Path):

languages = []
texts = []
total_seconds = 0
for task in tasks:
transcription_result = task.result()
total_seconds += int(round(transcription_result.model_extra.get("duration", 0), 0))
languages.append(transcription_result.model_extra.get("language"))
texts.append(transcription_result.text)

Expand All @@ -187,7 +192,7 @@ async def get_whisper_transcription(file_path: Path):
_logger.warning("Could not determine language of the transcription")
most_common_language = None

result = WhisperTranscription(text="\n".join(texts), language=most_common_language)
result = WhisperTranscription(text="\n".join(texts), language=most_common_language, total_seconds=total_seconds)

if temp_dir is not None:
temp_dir.cleanup()
Expand Down Expand Up @@ -224,7 +229,9 @@ def _summarize(update: telegram.Update, context: DbSessionContext, transcript: T
return transcript.summary

created_at = dt.datetime.now(dt.UTC)
summary_response = get_openai_chatcompletion(transcript.result)
openai_response: ParsedChatCompletion = get_openai_chatcompletion(transcript.result)
[choice] = openai_response.choices
summary_response: SummaryResponse = choice.message.parsed

if transcript.input_language is None or transcript.input_language.ietf_tag != summary_response.ietf_language_tag:
language_stmt = select(Language).where(Language.ietf_tag == summary_response.ietf_language_tag)
Expand All @@ -240,6 +247,10 @@ def _summarize(update: telegram.Update, context: DbSessionContext, transcript: T
topics=[Topic(text=text, order=i) for i, text in enumerate(summary_response.topics, start=1)],
tg_user_id=update.effective_user.id,
tg_chat_id=update.effective_chat.id,
openai_id=openai_response.id,
openai_model=openai_response.model,
completion_tokens=openai_response.usage.completion_tokens,
prompt_tokens=openai_response.usage.prompt_tokens,
)
transcript.reaction_emoji = summary_response.emoji
transcript.hashtags = summary_response.hashtags
Expand Down Expand Up @@ -359,7 +370,7 @@ class SummaryResponse(BaseModel):
hashtags: list[str]


def get_openai_chatcompletion(transcript: str) -> SummaryResponse:
def get_openai_chatcompletion(transcript: str) -> ParsedChatCompletion:
openai_model = os.getenv("OPENAI_MODEL_ID")
if openai_model is None:
raise ValueError("OPENAI_MODEL_ID environment variable not set")
Expand All @@ -382,7 +393,7 @@ def get_openai_chatcompletion(transcript: str) -> SummaryResponse:
{"role": "user", "content": transcript},
],
response_format=SummaryResponse,
n=1,
)

[choice] = summary_result.choices
return choice.message.parsed
return summary_result
28 changes: 28 additions & 0 deletions summaree_bot/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import enum
import json
import os
import re
import secrets
from datetime import datetime
from typing import List, Optional
Expand Down Expand Up @@ -224,6 +225,7 @@ class Transcript(Base):
mime_type: Mapped[str]
file_size: Mapped[int]
result: Mapped[str]
total_seconds: Mapped[Optional[int]]

finished_at: Mapped[Optional[datetime]]

Expand Down Expand Up @@ -253,6 +255,12 @@ class Summary(Base):
tg_chat_id: Mapped[Optional[BigInteger]] = mapped_column(ForeignKey("telegram_chat.id"))
tg_chat: Mapped["TelegramChat"] = relationship(back_populates="summaries")

# openai Data to track usage/costs
openai_id: Mapped[Optional[str]]
openai_model: Mapped[Optional[str]]
completion_tokens: Mapped[Optional[int]]
prompt_tokens: Mapped[Optional[int]]

messages: Mapped[List["BotMessage"]] = relationship(back_populates="summary")
topics: Mapped[List["Topic"]] = relationship(back_populates="summary")

Expand All @@ -271,6 +279,26 @@ def get_usage_stats(cls, session: Session) -> dict:

return query.all()

@property
def total_cost(self) -> Optional[float]:
    """Total USD cost of producing this summary (chat completion + transcription).

    Pricing used (per the OpenAI pricing page, as of 2024-09):
      - gpt-4o-mini: $0.15 per 1M prompt tokens, $0.60 per 1M completion tokens
      - whisper-1:   $0.006 per minute of transcribed audio

    Returns:
        The cost in USD, or ``None`` when the model is missing or is not a
        gpt-4o-mini variant (no pricing implemented for other models), so
        callers can distinguish "no cost data" from a zero cost.
    """
    # Only gpt-4o-mini pricing is implemented. `openai_model` is a nullable
    # column (rows predating the tracking migration have no value), so guard
    # against None before the prefix test; a plain startswith replaces the
    # previous lazy regex, which also crashed on None input.
    if not self.openai_model or not self.openai_model.startswith("gpt-4o-mini"):
        return None

    # Token counts and transcript duration are nullable as well; treat
    # missing values as zero instead of raising TypeError on arithmetic.
    completion_tokens = self.completion_tokens or 0
    prompt_tokens = self.prompt_tokens or 0
    # NOTE(review): assumes self.transcript is always set on a Summary —
    # confirm the relationship is non-nullable.
    transcribed_seconds = self.transcript.total_seconds or 0

    total_cost = (
        completion_tokens / 1e6 * 0.60
        + prompt_tokens / 1e6 * 0.15
        + transcribed_seconds / 60 * 0.006
    )
    return total_cost


class Topic(Base):
__tablename__ = "topic"
Expand Down

0 comments on commit 64b8b69

Please sign in to comment.