chore: count tokens before hitting OpenAI (#20621)
* chore: count tokens before hitting OpenAI

* log the offending input

---------

Co-authored-by: Paul D'Ambra <[email protected]>
daibhin and pauldambra authored Feb 28, 2024
1 parent 5e89d91 commit a9496ea
Showing 3 changed files with 34 additions and 4 deletions.
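The gist of the change: count tokens locally with tiktoken and skip (and log) any session whose input would exceed the embedding model's limit, rather than letting OpenAI reject the request. A minimal standalone sketch of that guard, using the constants and model name from the diff below (the bare OpenAI() client setup is an assumption for illustration; the real code configures its own client and logging):

    import tiktoken
    from openai import OpenAI
    from typing import List

    MAX_TOKENS_FOR_MODEL = 8191  # input limit the diff uses for text-embedding-3-small
    encoding = tiktoken.get_encoding("cl100k_base")
    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


    def embed_if_within_limit(input_text: str) -> List[float] | None:
        # Pre-flight check: count tokens locally so an oversized input never reaches the API.
        token_count = len(encoding.encode(input_text))
        if token_count > MAX_TOKENS_FOR_MODEL:
            # The commit also logs the offending input at this point before skipping.
            return None
        return (
            client.embeddings.create(input=input_text, model="text-embedding-3-small")
            .data[0]
            .embedding
        )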
32 changes: 28 additions & 4 deletions ee/session_recordings/ai/generate_embeddings.py
@@ -1,5 +1,6 @@
from django.conf import settings
from openai import OpenAI
import tiktoken

from typing import Dict, Any, List

@@ -26,6 +27,11 @@
buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20],
)

RECORDING_EMBEDDING_TOKEN_COUNT = Histogram(
"posthog_session_recordings_recording_embedding_token_count",
"Token count for individual recordings generated during embedding",
)

SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS = Counter(
"posthog_session_recordings_skipped_when_generating_embeddings",
"Number of sessions skipped when generating embeddings",
@@ -53,9 +59,13 @@

logger = get_logger(__name__)

# tiktoken.encoding_for_model(model_name) specifies encoder
# model_name = "text-embedding-3-small" for this usecase
encoding = tiktoken.get_encoding("cl100k_base")

BATCH_FLUSH_SIZE = settings.REPLAY_EMBEDDINGS_BATCH_SIZE
MIN_DURATION_INCLUDE_SECONDS = settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS
MAX_TOKENS_FOR_MODEL = 8191


def fetch_recordings_without_embeddings(team: Team | int, offset=0) -> List[str]:
@@ -165,7 +175,7 @@ def flush_embeddings_to_clickhouse(embeddings: List[Dict[str, Any]]) -> None:


def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None:
logger.error(f"generating embedding for session", flow="embeddings", session_id=session_id)
logger.info(f"generating embedding for session", flow="embeddings", session_id=session_id)
if isinstance(team, int):
team = Team.objects.get(id=team)

@@ -200,7 +210,7 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None:
)
)

logger.error(f"collapsed events for session", flow="embeddings", session_id=session_id)
logger.info(f"collapsed events for session", flow="embeddings", session_id=session_id)

processed_sessions_index = processed_sessions.column_index("event")
current_url_index = processed_sessions.column_index("$current_url")
@@ -219,7 +229,16 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None:
)
)

logger.error(f"generating embedding input for session", flow="embeddings", session_id=session_id)
logger.info(f"generating embedding input for session", flow="embeddings", session_id=session_id)

token_count = num_tokens_for_input(input)
RECORDING_EMBEDDING_TOKEN_COUNT.observe(token_count)
if token_count > MAX_TOKENS_FOR_MODEL:
logger.error(
f"embedding input exceeds max token count for model", flow="embeddings", session_id=session_id, input=input
)
SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS.inc()
return None

embeddings = (
client.embeddings.create(
@@ -230,11 +249,16 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None:
.embedding
)

logger.error(f"generated embedding input for session", flow="embeddings", session_id=session_id)
logger.info(f"generated embedding input for session", flow="embeddings", session_id=session_id)

return embeddings


def num_tokens_for_input(string: str) -> int:
"""Returns the number of tokens in a text string."""
return len(encoding.encode(string))


def compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str:
elements_string = elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain)
return f"{event_name} {current_url} {elements_string}"
1 change: 1 addition & 0 deletions requirements.in
@@ -97,6 +97,7 @@ more-itertools==9.0.0
django-two-factor-auth==1.14.0
phonenumberslite==8.13.6
openai==1.10.0
tiktoken==0.6.0
nh3==0.2.14
hogql-parser==1.0.3
urllib3[secure,socks]==1.26.18
5 changes: 5 additions & 0 deletions requirements.txt
@@ -528,6 +528,8 @@ referencing==0.31.1
# jsonschema
# jsonschema-path
# jsonschema-specifications
regex==2023.12.25
# via tiktoken
requests==2.31.0
# via
# -r requirements.in
@@ -543,6 +545,7 @@ requests==2.31.0
# snowflake-connector-python
# social-auth-core
# stripe
# tiktoken
# webdriver-manager
requests-oauthlib==1.3.0
# via
@@ -630,6 +633,8 @@ tenacity==8.2.3
# via
# celery-redbeat
# dlt
tiktoken==0.6.0
# via -r requirements.in
token-bucket==0.3.0
# via -r requirements.in
tomlkit==0.12.3
