diff --git a/bots/migrations/0060_conversation_reset_at.py b/bots/migrations/0060_conversation_reset_at.py new file mode 100644 index 000000000..10cd847b6 --- /dev/null +++ b/bots/migrations/0060_conversation_reset_at.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-20 16:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0059_savedrun_is_api_call'), + ] + + operations = [ + migrations.AddField( + model_name='conversation', + name='reset_at', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + ] diff --git a/bots/models.py b/bots/models.py index 3a440ab73..534ccdd7a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -826,6 +826,7 @@ class Conversation(models.Model): ) created_at = models.DateTimeField(auto_now_add=True) + reset_at = models.DateTimeField(null=True, blank=True, default=None) objects = ConversationQuerySet.as_manager() @@ -935,7 +936,9 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, - "Run Time": message.saved_run.run_time if message.saved_run else 0, # user messages have no run/run_time + "Run Time": ( + message.saved_run.run_time if message.saved_run else 0 + ), # user messages have no run/run_time } rows.append(row) df = pd.DataFrame.from_records( @@ -977,7 +980,11 @@ def to_df_analysis_format( ) return df - def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: + def as_llm_context( + self, limit: int = 50, reset_at: datetime.datetime = None + ) -> list["ConversationEntry"]: + if reset_at: + self = self.filter(created_at__gt=reset_at) msgs = self.order_by("-created_at").prefetch_related("attachments")[:limit] entries = [None] * len(msgs) for i, msg in enumerate(reversed(msgs)): diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 2b27ec0b6..7075e3b2d 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -254,8 +254,8 @@ def _on_msg(bot: BotInterface): return # handle reset keyword if input_text.lower() == RESET_KEYWORD: - # clear saved messages - bot.convo.messages.all().delete() + # record the reset time so we don't send context + bot.convo.reset_at = timezone.now() # reset convo state bot.convo.state = ConvoState.INITIAL bot.convo.save() @@ -317,8 +317,8 @@ def _process_and_send_msg( recieved_time: datetime, speech_run: str | None, ): - # get latest messages for context (upto 100) - saved_msgs = bot.convo.messages.all().as_llm_context() + # get latest messages for context + saved_msgs = bot.convo.messages.all().as_llm_context(reset_at=bot.convo.reset_at) # # mock testing # result = _mock_api_output(input_text)