From a9496eace8a5dd969f8fc7d494e6676cb508fd6e Mon Sep 17 00:00:00 2001
From: David Newell
Date: Wed, 28 Feb 2024 23:04:43 +0000
Subject: [PATCH] chore: count tokens before hitting OpenAI (#20621)

* chore: count tokens before hitting OpenAI

* log the offending input

---------

Co-authored-by: Paul D'Ambra
---
 .../ai/generate_embeddings.py | 32 ++++++++++++++++---
 requirements.in               |  1 +
 requirements.txt              |  5 +++
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/ee/session_recordings/ai/generate_embeddings.py b/ee/session_recordings/ai/generate_embeddings.py
index 7dd54de3da95e..2ba83db1f220d 100644
--- a/ee/session_recordings/ai/generate_embeddings.py
+++ b/ee/session_recordings/ai/generate_embeddings.py
@@ -1,5 +1,6 @@
 from django.conf import settings
 from openai import OpenAI
+import tiktoken
 
 from typing import Dict, Any, List
 
@@ -26,6 +27,11 @@
     buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20],
 )
 
+RECORDING_EMBEDDING_TOKEN_COUNT = Histogram(
+    "posthog_session_recordings_recording_embedding_token_count",
+    "Token count for individual recordings generated during embedding",
+)
+
 SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS = Counter(
     "posthog_session_recordings_skipped_when_generating_embeddings",
     "Number of sessions skipped when generating embeddings",
 )
@@ -53,9 +59,13 @@
 
 logger = get_logger(__name__)
 
+# tiktoken.encoding_for_model(model_name) specifies the encoder;
+# model_name = "text-embedding-3-small" for this use case
+encoding = tiktoken.get_encoding("cl100k_base")
 
 BATCH_FLUSH_SIZE = settings.REPLAY_EMBEDDINGS_BATCH_SIZE
 MIN_DURATION_INCLUDE_SECONDS = settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS
+MAX_TOKENS_FOR_MODEL = 8191
 
 
 def fetch_recordings_without_embeddings(team: Team | int, offset=0) -> List[str]:
@@ -165,7 +175,7 @@ def flush_embeddings_to_clickhouse(embeddings: List[Dict[str, Any]]) -> None:
 
 
 def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None:
-    logger.error(f"generating embedding for session", flow="embeddings", session_id=session_id)
+    logger.info(f"generating embedding for session", flow="embeddings", session_id=session_id)
 
     if isinstance(team, int):
         team = Team.objects.get(id=team)
@@ -200,7 +210,7 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[flo
         )
     )
 
-    logger.error(f"collapsed events for session", flow="embeddings", session_id=session_id)
+    logger.info(f"collapsed events for session", flow="embeddings", session_id=session_id)
 
     processed_sessions_index = processed_sessions.column_index("event")
     current_url_index = processed_sessions.column_index("$current_url")
@@ -219,7 +229,16 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[flo
         )
     )
 
-    logger.error(f"generating embedding input for session", flow="embeddings", session_id=session_id)
+    logger.info(f"generating embedding input for session", flow="embeddings", session_id=session_id)
+
+    token_count = num_tokens_for_input(input)
+    RECORDING_EMBEDDING_TOKEN_COUNT.observe(token_count)
+    if token_count > MAX_TOKENS_FOR_MODEL:
+        logger.error(
+            f"embedding input exceeds max token count for model", flow="embeddings", session_id=session_id, input=input
+        )
+        SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS.inc()
+        return None
 
     embeddings = (
         client.embeddings.create(
@@ -230,11 +249,16 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[flo
         .embedding
     )
 
-    logger.error(f"generated embedding input for session", flow="embeddings", session_id=session_id)
logger.info(f"generated embedding input for session", flow="embeddings", session_id=session_id) return embeddings +def num_tokens_for_input(string: str) -> int: + """Returns the number of tokens in a text string.""" + return len(encoding.encode(string)) + + def compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str: elements_string = elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain) return f"{event_name} {current_url} {elements_string}" diff --git a/requirements.in b/requirements.in index 0034e45a30243..8404285cc4114 100644 --- a/requirements.in +++ b/requirements.in @@ -97,6 +97,7 @@ more-itertools==9.0.0 django-two-factor-auth==1.14.0 phonenumberslite==8.13.6 openai==1.10.0 +tiktoken==0.6.0 nh3==0.2.14 hogql-parser==1.0.3 urllib3[secure,socks]==1.26.18 diff --git a/requirements.txt b/requirements.txt index e2ddc0b28f311..97549e8d09fae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -528,6 +528,8 @@ referencing==0.31.1 # jsonschema # jsonschema-path # jsonschema-specifications +regex==2023.12.25 + # via tiktoken requests==2.31.0 # via # -r requirements.in @@ -543,6 +545,7 @@ requests==2.31.0 # snowflake-connector-python # social-auth-core # stripe + # tiktoken # webdriver-manager requests-oauthlib==1.3.0 # via @@ -630,6 +633,8 @@ tenacity==8.2.3 # via # celery-redbeat # dlt +tiktoken==0.6.0 + # via -r requirements.in token-bucket==0.3.0 # via -r requirements.in tomlkit==0.12.3