From 2158178a02d89a44100150ee4d249a14f876f133 Mon Sep 17 00:00:00 2001 From: Paul D'Ambra Date: Mon, 12 Feb 2024 22:11:42 +0000 Subject: [PATCH] start wiring up Celery task --- .../ai/generate_embeddings.py | 5 +++- ee/tasks/replay_summaries.py | 25 ++++++++----------- posthog/tasks/tasks.py | 4 +-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ee/session_recordings/ai/generate_embeddings.py b/ee/session_recordings/ai/generate_embeddings.py index cfcc6d2be3120..5ff142abbd21c 100644 --- a/ee/session_recordings/ai/generate_embeddings.py +++ b/ee/session_recordings/ai/generate_embeddings.py @@ -121,7 +121,10 @@ def flush_embeddings_to_clickhouse(embeddings: List[Dict[str, Any]]) -> None: sync_execute("INSERT INTO session_replay_embeddings (session_id, team_id, embeddings) VALUES", embeddings) -def generate_recording_embeddings(session_id: str, team: Team) -> List[float] | None: +def generate_recording_embeddings(session_id: str, team: Team | int) -> List[float] | None: + if isinstance(team, int): + team = Team.objects.get(pk=team) + client = OpenAI() session_metadata = SessionReplayEvents().get_metadata(session_id=str(session_id), team=team) diff --git a/ee/tasks/replay_summaries.py b/ee/tasks/replay_summaries.py index 93102698a3e4b..d4c10bf10f2a6 100644 --- a/ee/tasks/replay_summaries.py +++ b/ee/tasks/replay_summaries.py @@ -1,27 +1,24 @@ import structlog from celery import shared_task -# from ee.session_recordings.ai.generate_embeddings import ( -# generate_recording_embedding, -# fetch_recordings_without_embeddings, -# ) +from ee.session_recordings.ai.generate_embeddings import ( + fetch_recordings_without_embeddings, + generate_recording_embeddings, +) from posthog.tasks.utils import CeleryQueue logger = structlog.get_logger(__name__) -# just so we can merge into another PR @shared_task(ignore_result=True, queue=CeleryQueue.SESSION_REPLAY_EMBEDDINGS.value) def embed_single_recording(session_id: str, team_id: int) -> None: - # generate_recording_embedding(session_id, team_id) - pass + generate_recording_embeddings(session_id, team_id) -# just so we can merge into another PR @shared_task(ignore_result=True) -def generate_recording_embeddings() -> None: - # recordings = fetch_recordings_without_embeddings() - # for recording in []: # recordings: - # # push each embedding task to a separate queue - # embed_single_recording.delay(recording.session_id, recording.team_id) - pass +def generate_recordings_embeddings_batch() -> None: + for recording in fetch_recordings_without_embeddings(): + # push each embedding task to a separate queue + # TODO really we should be doing scatter and gather here + # so we can do one CH update at the end of a batch + embed_single_recording.delay(recording.session_id, recording.team_id) diff --git a/posthog/tasks/tasks.py b/posthog/tasks/tasks.py index 6dc5839cad93a..c32cbab2b31f8 100644 --- a/posthog/tasks/tasks.py +++ b/posthog/tasks/tasks.py @@ -723,8 +723,8 @@ def check_data_import_row_limits() -> None: @shared_task(ignore_result=True) def calculate_replay_embeddings() -> None: try: - from ee.tasks.replay_summaries import generate_recording_embeddings + from ee.tasks.replay_summaries import generate_recordings_embeddings_batch - generate_recording_embeddings() + generate_recordings_embeddings_batch() except ImportError: pass