feat: embed errors (#20752)
daibhin authored Mar 8, 2024
1 parent a2109a2 commit ea340fc
Showing 4 changed files with 417 additions and 298 deletions.
110 changes: 110 additions & 0 deletions ee/session_recordings/ai/embeddings_queries.py
@@ -0,0 +1,110 @@
from django.conf import settings

from typing import List, Tuple

from posthog.models import Team
from posthog.clickhouse.client import sync_execute

BATCH_FLUSH_SIZE = settings.REPLAY_EMBEDDINGS_BATCH_SIZE
MIN_DURATION_INCLUDE_SECONDS = settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS


def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[Tuple[str, str]]:
query = """
WITH embedded_sessions AS (
SELECT
session_id
FROM
session_replay_embeddings
WHERE
team_id = %(team_id)s
-- don't load all data for all time
AND generation_timestamp > now() - INTERVAL 7 DAY
AND source_type = 'error'
)
SELECT log_source_id, message
FROM log_entries
PREWHERE
team_id = %(team_id)s
AND level = 'error'
AND log_source = 'session_replay'
AND timestamp <= now()
AND timestamp >= now() - INTERVAL 7 DAY
AND log_source_id NOT IN embedded_sessions
LIMIT %(batch_flush_size)s
-- when running locally the offset is used for paging
-- when running in celery the offset is not used
OFFSET %(offset)s
"""

return sync_execute(
query,
{
"team_id": team_id,
"batch_flush_size": BATCH_FLUSH_SIZE,
"offset": offset,
},
)


def fetch_recordings_without_embeddings(team_id: int, offset=0) -> List[str]:
team = Team.objects.get(id=team_id)

query = """
WITH embedding_ids AS
(
SELECT
session_id
FROM
session_replay_embeddings
WHERE
team_id = %(team_id)s
-- don't load all data for all time
AND generation_timestamp > now() - INTERVAL 7 DAY
),
replay_with_events AS
(
SELECT
distinct $session_id
FROM
events
WHERE
team_id = %(team_id)s
-- don't load all data for all time
AND timestamp > now() - INTERVAL 7 DAY
AND timestamp < now()
AND $session_id IS NOT NULL AND $session_id != ''
)
SELECT session_id
FROM
session_replay_events
WHERE
session_id NOT IN embedding_ids
AND team_id = %(team_id)s
-- must be a completed session
AND min_first_timestamp < now() - INTERVAL 1 DAY
-- let's not load all data for all time
-- will definitely need to do something about this length of time
AND min_first_timestamp > now() - INTERVAL 7 DAY
AND session_id IN replay_with_events
GROUP BY session_id
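            -- only embed sessions longer than the configured minimum duration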
HAVING dateDiff('second', min(min_first_timestamp), max(max_last_timestamp)) > %(min_duration_include_seconds)s
ORDER BY rand()
LIMIT %(batch_flush_size)s
-- when running locally the offset is used for paging
-- when running in celery the offset is not used
OFFSET %(offset)s
"""

return [
x[0]
for x in sync_execute(
query,
{
"team_id": team.pk,
"batch_flush_size": BATCH_FLUSH_SIZE,
"offset": offset,
"min_duration_include_seconds": MIN_DURATION_INCLUDE_SECONDS,
},
)
]
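
For context, a minimal sketch of how this helper might be paged when running locally (per the query comments above, the offset only matters outside Celery). TEAM_ID and the loop are illustrative assumptions, not code from this commit:

from ee.session_recordings.ai.embeddings_queries import fetch_recordings_without_embeddings

TEAM_ID = 1  # hypothetical team id, for illustration only

offset = 0
while True:
    # each call returns at most BATCH_FLUSH_SIZE session ids that still need embeddings
    session_ids = fetch_recordings_without_embeddings(TEAM_ID, offset)
    if not session_ids:
        break
    offset += len(session_ids)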
284 changes: 284 additions & 0 deletions ee/session_recordings/ai/embeddings_runner.py
@@ -0,0 +1,284 @@
import json
import tiktoken
import datetime
import pytz

from typing import Dict, Any, List, Tuple, Optional

from abc import ABC, abstractmethod
from prometheus_client import Histogram, Counter
from structlog import get_logger
from openai import OpenAI

from posthog.models import Team
from posthog.clickhouse.client import sync_execute

from posthog.session_recordings.queries.session_replay_events import SessionReplayEvents
from ee.session_recordings.ai.utils import (
SessionSummaryPromptData,
reduce_elements_chain,
simplify_window_id,
format_dates,
collapse_sequence_of_events,
only_pageview_urls,
)

# tiktoken.encoding_for_model(model_name) specifies the encoder;
# model_name = "text-embedding-3-small" for this use case
encoding = tiktoken.get_encoding("cl100k_base")

MAX_TOKENS_FOR_MODEL = 8191

RECORDING_EMBEDDING_TOKEN_COUNT = Histogram(
    "posthog_session_recordings_recording_embedding_token_count",
    "Token count for individual recordings generated during embedding",
    buckets=[0, 100, 500, 1000, 2000, 3000, 4000, 5000, 6000, 8000, 10000],
    labelnames=["source_type"],
)

GENERATE_RECORDING_EMBEDDING_TIMING = Histogram(
    "posthog_session_recordings_generate_recording_embedding",
    "Time spent generating recording embeddings for a single session",
    buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20],
    labelnames=["source_type"],
)

SESSION_EMBEDDINGS_GENERATED = Counter(
    "posthog_session_recordings_embeddings_generated",
    "Number of session embeddings generated",
    labelnames=["source_type"],
)

SESSION_EMBEDDINGS_FAILED = Counter(
    "posthog_session_recordings_embeddings_failed",
    "Instance of an embedding request to OpenAI (and its surrounding work) failing and being swallowed",
    labelnames=["source_type"],
)

SESSION_EMBEDDINGS_FATAL_FAILED = Counter(
    "posthog_session_recordings_embeddings_fatal_failed",
    "Instance of the embeddings task failing and raising an exception",
    labelnames=["source_type"],
)

SESSION_EMBEDDINGS_WRITTEN_TO_CLICKHOUSE = Counter(
    "posthog_session_recordings_embeddings_written_to_clickhouse",
    "Number of session embeddings written to ClickHouse",
    labelnames=["source_type"],
)

SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS = Counter(
    "posthog_session_recordings_skipped_when_generating_embeddings",
    "Number of sessions skipped when generating embeddings",
    labelnames=["source_type", "reason"],
)

SESSION_EMBEDDINGS_FAILED_TO_CLICKHOUSE = Counter(
    "posthog_session_recordings_embeddings_failed_to_clickhouse",
    "Number of session embeddings that failed to be written to ClickHouse",
    labelnames=["source_type"],
)


logger = get_logger(__name__)


class EmbeddingPreparation(ABC):
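    """Turns a raw work item into a (session_id, input_text) pair for embedding; implementations return None to skip an item."""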
source_type: str

@staticmethod
@abstractmethod
    def prepare(item, team) -> Optional[Tuple[str, str]]:
raise NotImplementedError()


class SessionEmbeddingsRunner:
    team: Team
    openai_client: OpenAI

    def __init__(self, team: Team):
        self.team = team
        self.openai_client = OpenAI()

def run(self, items: List[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None:
source_type = embeddings_preparation.source_type

try:
batched_embeddings = []

for item in items:
try:
logger.info(
f"generating embedding input for item",
flow="embeddings",
item=json.dumps(item),
source_type=source_type,
)

result = embeddings_preparation.prepare(item, self.team)

if result:
session_id, input = result

logger.info(
f"generating embedding for item",
flow="embeddings",
session_id=session_id,
source_type=source_type,
)

with GENERATE_RECORDING_EMBEDDING_TIMING.labels(source_type=source_type).time():
embeddings = self._embed_input(input, source_type=source_type)

logger.info(
f"generated embedding for item",
flow="embeddings",
session_id=session_id,
source_type=source_type,
)

if embeddings:
SESSION_EMBEDDINGS_GENERATED.labels(source_type=source_type).inc()
batched_embeddings.append(
{
"team_id": self.team.pk,
"session_id": session_id,
"embeddings": embeddings,
"source_type": source_type,
}
)
# we don't want to fail the whole batch if only a single recording fails
except Exception as e:
SESSION_EMBEDDINGS_FAILED.labels(source_type=source_type).inc()
logger.error(
f"embed individual item error",
flow="embeddings",
error=e,
source_type=source_type,
)
# so we swallow errors here

if len(batched_embeddings) > 0:
self._flush_embeddings_to_clickhouse(embeddings=batched_embeddings, source_type=source_type)
except Exception as e:
# but we don't swallow errors within the wider task itself
# if something is failing here then we're most likely having trouble with ClickHouse
SESSION_EMBEDDINGS_FATAL_FAILED.labels(source_type=source_type).inc()
logger.error(f"embed items fatal error", flow="embeddings", error=e, source_type=source_type)
raise e

def _embed_input(self, input: str, source_type: str):
token_count = self._num_tokens_for_input(input)
RECORDING_EMBEDDING_TOKEN_COUNT.labels(source_type=source_type).observe(token_count)
if token_count > MAX_TOKENS_FOR_MODEL:
logger.error(
f"embedding input exceeds max token count for model",
flow="embeddings",
input=json.dumps(input),
source_type=source_type,
)
SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS.labels(
source_type=source_type, reason="token_count_too_high"
).inc()
return None

return (
self.openai_client.embeddings.create(
input=input,
model="text-embedding-3-small",
)
.data[0]
.embedding
)

def _num_tokens_for_input(self, string: str) -> int:
"""Returns the number of tokens in a text string."""
return len(encoding.encode(string))

def _flush_embeddings_to_clickhouse(self, embeddings: List[Dict[str, Any]], source_type: str) -> None:
try:
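            # the list-of-dicts payload becomes a single batched VALUES insert via the ClickHouse driver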
sync_execute(
"INSERT INTO session_replay_embeddings (session_id, team_id, embeddings, source_type) VALUES",
embeddings,
)
SESSION_EMBEDDINGS_WRITTEN_TO_CLICKHOUSE.labels(source_type=source_type).inc(len(embeddings))
except Exception as e:
logger.error(f"flush embeddings error", flow="embeddings", error=e, source_type=source_type)
SESSION_EMBEDDINGS_FAILED_TO_CLICKHOUSE.labels(source_type=source_type).inc(len(embeddings))
raise e


class ErrorEmbeddingsPreparation(EmbeddingPreparation):
source_type = "error"

@staticmethod
def prepare(item: Tuple[str, str], _):
session_id = item[0]
error_message = item[1]
return session_id, error_message


class SessionEventsEmbeddingsPreparation(EmbeddingPreparation):
source_type = "session"

@staticmethod
def prepare(session_id: str, team: Team):
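        # look back slightly beyond the 7-day window used in embeddings_queries when resolving metadata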
eight_days_ago = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=8)
session_metadata = SessionReplayEvents().get_metadata(
session_id=str(session_id), team=team, recording_start_time=eight_days_ago
)
if not session_metadata:
logger.error(f"no session metadata found for session", flow="embeddings", session_id=session_id)
SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS.labels(
source_type=SessionEventsEmbeddingsPreparation.source_type, reason="metadata_missing"
).inc()
return None

session_events = SessionReplayEvents().get_events(
session_id=str(session_id),
team=team,
metadata=session_metadata,
events_to_ignore=[
"$feature_flag_called",
],
)

if not session_events or not session_events[0] or not session_events[1]:
logger.error(f"no events found for session", flow="embeddings", session_id=session_id)
SESSION_SKIPPED_WHEN_GENERATING_EMBEDDINGS.labels(
source_type=SessionEventsEmbeddingsPreparation.source_type, reason="events_missing"
).inc()
return None

processed_sessions = collapse_sequence_of_events(
only_pageview_urls(
format_dates(
reduce_elements_chain(
simplify_window_id(
SessionSummaryPromptData(columns=session_events[0], results=session_events[1])
)
),
start=datetime.datetime(1970, 1, 1, tzinfo=pytz.UTC), # epoch timestamp
)
)
)

logger.info(f"collapsed events for session", flow="embeddings", session_id=session_id)

processed_sessions_index = processed_sessions.column_index("event")
current_url_index = processed_sessions.column_index("$current_url")
elements_chain_index = processed_sessions.column_index("elements_chain")

input = (
str(session_metadata)
+ "\n"
+ "\n".join(
SessionEventsEmbeddingsPreparation._compact_result(
event_name=result[processed_sessions_index] if processed_sessions_index is not None else "",
current_url=result[current_url_index] if current_url_index is not None else "",
elements_chain=result[elements_chain_index] if elements_chain_index is not None else "",
)
for result in processed_sessions.results
)
)

return session_id, input

@staticmethod
    def _compact_result(event_name: str, current_url: str, elements_chain: Dict[str, str] | str) -> str:
elements_string = (
elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain)
)
return f"{event_name} {current_url} {elements_string}"
