Skip to content

Commit

Permalink
v0
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Sep 19, 2023
1 parent 4dd6522 commit 4b772bd
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 35 deletions.
10 changes: 6 additions & 4 deletions bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from django.utils.timesince import timesince
from furl import furl

from app_users.admin import AppUserAdmin
from app_users.models import AppUser
from bots.admin_links import list_related_html_url, open_in_new_tab, change_obj_url
from bots.models import (
FeedbackComment,
Expand Down Expand Up @@ -88,7 +90,6 @@ class Media:
class BotIntegrationAdmin(admin.ModelAdmin):
search_fields = [
"name",
"billing_account_uid",
"user_language",
"fb_page_id",
"fb_page_name",
Expand All @@ -103,23 +104,23 @@ class BotIntegrationAdmin(admin.ModelAdmin):
"slack_channel_name",
"slack_channel_hook_url",
"slack_access_token",
]
] + ["user__" + field for field in AppUserAdmin.search_fields]
list_display = [
"name",
"get_display_name",
"platform",
"wa_phone_number",
"created_at",
"updated_at",
"billing_account_uid",
"user",
"saved_run",
"analysis_run",
]
list_filter = ["platform"]

form = BotIntegrationAdminForm

autocomplete_fields = ["saved_run", "analysis_run"]
autocomplete_fields = ["saved_run", "analysis_run", "user"]

readonly_fields = [
"fb_page_access_token",
Expand All @@ -140,6 +141,7 @@ class BotIntegrationAdmin(admin.ModelAdmin):
"name",
"saved_run",
"billing_account_uid",
"user",
"user_language",
],
},
Expand Down
89 changes: 89 additions & 0 deletions bots/migrations/0043_botintegration_user_savedrun_user_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Generated by Django 4.2.3 on 2023-09-19 21:52

import django.db.models.deletion
from django.db import migrations, models
from django.db.models import Subquery, OuterRef


def forwards_func(apps, schema_editor):
AppUser = apps.get_model("app_users", "AppUser")
SavedRun = apps.get_model("bots", "SavedRun")
BotIntegration = apps.get_model("bots", "BotIntegration")
db_alias = schema_editor.connection.alias
user_objects = AppUser.objects.using(db_alias)
SavedRun.objects.using(db_alias).filter(uid__isnull=False).update(
user=Subquery(
user_objects.filter(uid=OuterRef("uid")).values("pk")[:1],
),
)
BotIntegration.objects.using(db_alias).all().update(
user=Subquery(
user_objects.filter(uid=OuterRef("billing_account_uid")).values("pk")[:1],
),
)


class Migration(migrations.Migration):
dependencies = [
("app_users", "0006_appuser_disable_safety_checker"),
("bots", "0042_alter_message_platform_msg_id"),
]

operations = [
migrations.AddField(
model_name="botintegration",
name="user",
field=models.ForeignKey(
default=None,
help_text="The gooey account uid where the credits will be deducted from",
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="bot_integrations",
to="app_users.appuser",
),
),
migrations.AddField(
model_name="savedrun",
name="user",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="saved_runs",
to="app_users.appuser",
),
),
migrations.AlterField(
model_name="botintegration",
name="billing_account_uid",
field=models.TextField(
db_index=True, help_text="DEPRECATED: use user field instead"
),
),
migrations.AlterField(
model_name="savedrun",
name="uid",
field=models.CharField(
blank=True,
default=None,
help_text="DEPRECATED: use user filed instead",
max_length=128,
null=True,
),
),
migrations.RunPython(
forwards_func,
migrations.RunPython.noop,
),
migrations.RemoveIndex(
model_name="botintegration",
name="bots_botint_billing_466a86_idx",
),
migrations.AddIndex(
model_name="botintegration",
index=models.Index(
fields=["user", "platform"], name="bots_botint_user_id_2c82dd_idx"
),
),
]
35 changes: 29 additions & 6 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,21 @@ class SavedRun(models.Model):
)
example_id = models.CharField(max_length=128, default=None, null=True, blank=True)
run_id = models.CharField(max_length=128, default=None, null=True, blank=True)
uid = models.CharField(max_length=128, default=None, null=True, blank=True)
uid = models.CharField(
max_length=128,
default=None,
null=True,
blank=True,
help_text="DEPRECATED: use user filed instead",
)
user = models.ForeignKey(
"app_users.AppUser",
on_delete=models.SET_NULL,
null=True,
blank=True,
default=None,
related_name="saved_runs",
)

state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)

Expand Down Expand Up @@ -259,7 +273,7 @@ def _parse_dt(dt) -> datetime.datetime | None:
class BotIntegrationQuerySet(models.QuerySet):
@transaction.atomic()
def reset_fb_pages_for_user(
self, uid: str, fb_pages: list[dict]
self, user: AppUser, fb_pages: list[dict]
) -> list["BotIntegration"]:
saved = []
for fb_page in fb_pages:
Expand All @@ -274,7 +288,7 @@ def reset_fb_pages_for_user(
)
except BotIntegration.DoesNotExist:
bi = BotIntegration(fb_page_id=fb_page_id)
bi.billing_account_uid = uid
bi.user = user
bi.fb_page_name = fb_page["name"]
# bi.fb_user_access_token = user_access_token
bi.fb_page_access_token = fb_page["access_token"]
Expand All @@ -294,7 +308,7 @@ def reset_fb_pages_for_user(
# delete pages that are no longer connected for this user
self.filter(
Q(platform=Platform.FACEBOOK) | Q(platform=Platform.INSTAGRAM),
billing_account_uid=uid,
user=user,
).exclude(
id__in=[bi.id for bi in saved],
).delete()
Expand All @@ -316,9 +330,18 @@ class BotIntegration(models.Model):
help_text="The saved run that the bot is based on",
)
billing_account_uid = models.TextField(
help_text="The gooey account uid where the credits will be deducted from",
help_text="DEPRECATED: use user field instead",
db_index=True,
)
user = models.ForeignKey(
"app_users.AppUser",
on_delete=models.CASCADE,
null=True,
default=None,
related_name="bot_integrations",
help_text="The gooey account uid where the credits will be deducted from",
)

user_language = models.TextField(
default="en",
help_text="The response language (same as user language in video bots)",
Expand Down Expand Up @@ -444,7 +467,7 @@ class Meta:
("slack_channel_id", "slack_team_id"),
]
indexes = [
models.Index(fields=["billing_account_uid", "platform"]),
models.Index(fields=["user", "platform"]),
models.Index(fields=["fb_page_id", "ig_account_id"]),
]

Expand Down
3 changes: 1 addition & 2 deletions bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ def msg_analysis(msg_id: int):
assert analysis_sr, "bot integration must have an analysis run"

# make the api call
billing_account = AppUser.objects.get(uid=bi.billing_account_uid)
variables = dict(
user_msg=msg.get_previous_by_created_at().content,
assistant_msg=msg.content,
bot_script=msg.saved_run.state.get("bot_script", ""),
references=references_as_prompt(msg.saved_run.state.get("references", [])),
)
result, sr = analysis_sr.submit_api_call(
current_user=billing_account,
current_user=bi.user,
request_body=dict(variables=variables),
)

Expand Down
2 changes: 1 addition & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def send_email_on_completion(page: BasePage, sr: SavedRun):
):
return
to_address = (
AppUser.objects.filter(uid=sr.uid).values_list("email", flat=True).first()
AppUser.objects.filter(uid=sr.user.uid).values_list("email", flat=True).first()
)
if not to_address:
return
Expand Down
6 changes: 3 additions & 3 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,14 @@ def recipe_doc_sr(self) -> SavedRun:
return SavedRun.objects.get_or_create(
workflow=self.workflow,
run_id__isnull=True,
uid__isnull=True,
user__isnull=True,
example_id__isnull=True,
)[0]

def run_doc_sr(
self, run_id: str, uid: str, create: bool = False, parent: SavedRun = None
) -> SavedRun:
config = dict(workflow=self.workflow, uid=uid, run_id=run_id)
config = dict(workflow=self.workflow, user__uid=uid, run_id=run_id)
if create:
return SavedRun.objects.get_or_create(
**config, defaults=dict(parent=parent)
Expand Down Expand Up @@ -872,7 +872,7 @@ def _history_tab(self):
run_history = list(
SavedRun.objects.filter(
workflow=self.workflow,
uid=uid,
user=self.request.user,
updated_at__lt=before,
)[:25]
)
Expand Down
54 changes: 42 additions & 12 deletions daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from daras_ai_v2.base import BasePage
from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT
from gooeysite.bg_db_conn import db_middleware
from routers.api import submit_api_call


async def request_json(request: Request):
Expand All @@ -35,7 +36,6 @@ async def request_urlencoded_body(request: Request):
class BotInterface:
input_message: dict
platform: Platform
billing_account_uid: str
page_cls: typing.Type[BasePage] | None
query_params: dict
bot_id: str
Expand Down Expand Up @@ -76,13 +76,12 @@ def _unpack_bot_integration(self):
self.query_params = self.page_cls.clean_query_params(
example_id=bi.saved_run.example_id,
run_id=bi.saved_run.run_id,
uid=bi.saved_run.uid,
uid=bi.saved_run.user.uid,
)
else:
self.page_cls = None
self.query_params = {}

self.billing_account_uid = bi.billing_account_uid
self.language = bi.user_language
self.show_feedback_buttons = bi.show_feedback_buttons

Expand Down Expand Up @@ -160,10 +159,6 @@ def _on_msg(bot: BotInterface):
if bot.input_type != "interactive":
# mark message as read
bot.mark_read()
# get the attached billing account
billing_account_user = AppUser.objects.get_or_create_from_uid(
bot.billing_account_uid
)[0]
# get the user's input
# print("input type:", bot.input_type)
match bot.input_type:
Expand All @@ -173,7 +168,7 @@ def _on_msg(bot: BotInterface):
return
case "audio" | "video":
try:
result = _handle_audio_msg(billing_account_user, bot)
result = _handle_audio_msg(bot)
speech_run = result.get("url")
except HTTPException as e:
traceback.print_exc()
Expand Down Expand Up @@ -349,7 +344,7 @@ def _handle_interactive_msg(bot: BotInterface):
)


def _handle_audio_msg(billing_account_user, bot: BotInterface):
def _handle_audio_msg(bot: BotInterface):
from recipes.asr import AsrPage
from routers.api import call_api

Expand All @@ -374,18 +369,53 @@ def _handle_audio_msg(billing_account_user, bot: BotInterface):
case _:
selected_model = AsrModels.whisper_large_v2.name

result = call_api(
# result = call_api(
# page_cls=AsrPage,
# user=bot.convo.bot_integration.user,
# request_body={
# "documents": [input_audio],
# "selected_model": selected_model,
# "google_translate_target": None,
# "language": language,
# },
# query_params={},
# )
self, result, run_id, uid = submit_api_call(
page_cls=AsrPage,
user=billing_account_user,
request_body={
"documents": [input_audio],
"selected_model": selected_model,
"google_translate_target": None,
"language": language,
},
user=bot.convo.bot_integration.user,
query_params={},
)
return result
# wait for the result
result.get(disable_sync_subtasks=False)
state = self.run_doc_sr(run_id, uid).to_dict()
# check for errors
return state
# err_msg = state.get(StateKeys.error_msg)
# if err_msg:
# raise HTTPException(
# status_code=500,
# detail={
# "id": run_id,
# "url": web_url,
# "created_at": created_at,
# "error": err_msg,
# },
# )
# else:
# # return updated state
# return {
# "id": run_id,
# "url": web_url,
# "created_at": created_at,
# "output": state,
# }
# return result


class ButtonIds:
Expand Down
Loading

0 comments on commit 4b772bd

Please sign in to comment.