diff --git a/auth_backend.py b/auth/auth_backend.py
similarity index 76%
rename from auth_backend.py
rename to auth/auth_backend.py
index 24dded058..d197ef213 100644
--- a/auth_backend.py
+++ b/auth/auth_backend.py
@@ -6,36 +6,37 @@
from starlette.concurrency import run_in_threadpool
from app_users.models import AppUser
-from daras_ai_v2.crypto import get_random_string, get_random_doc_id
+from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.db import FIREBASE_SESSION_COOKIE, ANONYMOUS_USER_COOKIE
from gooeysite.bg_db_conn import db_middleware
# quick and dirty way to bypass authentication for testing
-_forced_auth_user = []
+authlocal = []
@contextmanager
def force_authentication():
- user = AppUser.objects.create(
- is_anonymous=True, uid=get_random_doc_id(), balance=10**9
+ authlocal.append(
+ AppUser.objects.get_or_create(
+ email="tests@pytest.org",
+ defaults=dict(is_anonymous=True, uid=get_random_doc_id(), balance=10**9),
+ )[0]
)
try:
- _forced_auth_user.append(user)
- yield
+ yield authlocal[0]
finally:
- _forced_auth_user.clear()
+ authlocal.clear()
class SessionAuthBackend(AuthenticationBackend):
async def authenticate(self, conn):
- return await run_in_threadpool(authenticate, conn)
+ if authlocal:
+ return AuthCredentials(["authenticated"]), authlocal[0]
+ return await run_in_threadpool(_authenticate, conn)
@db_middleware
-def authenticate(conn):
- if _forced_auth_user:
- return AuthCredentials(["authenticated"]), _forced_auth_user[0]
-
+def _authenticate(conn):
session_cookie = conn.session.get(FIREBASE_SESSION_COOKIE)
if not session_cookie:
# Session cookie is unavailable. Check if anonymous user is available.
@@ -51,7 +52,7 @@ def authenticate(conn):
# Session cookie is unavailable. Force user to login.
return AuthCredentials(), None
- user = verify_session_cookie(session_cookie)
+ user = _verify_session_cookie(session_cookie)
if not user:
# Session cookie was invalid
conn.session.pop(FIREBASE_SESSION_COOKIE, None)
@@ -60,7 +61,7 @@ def authenticate(conn):
return AuthCredentials(["authenticated"]), user
-def verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
+def _verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
# Verify the session cookie. In this case an additional check is added to detect
# if the user's Firebase session was revoked, user deleted/disabled, etc.
try:
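
To illustrate the new contract: `force_authentication()` now yields the forced user instead of returning nothing, and `SessionAuthBackend` checks `authlocal` before touching the session. A minimal usage sketch (the endpoint path and test body are illustrative, not part of this change):

```python
from starlette.testclient import TestClient

from auth.auth_backend import force_authentication
from server import app

client = TestClient(app)


def test_as_forced_user(transactional_db):  # needs DB access to create the user
    with force_authentication() as user:
        # Inside this block SessionAuthBackend short-circuits on `authlocal`
        # and authenticates every request as `user`.
        assert user.email == "tests@pytest.org"
        r = client.get("/")  # illustrative endpoint
        assert r.ok
```
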
diff --git a/gooey_token_authentication1/token_authentication.py b/auth/token_authentication.py
similarity index 94%
rename from gooey_token_authentication1/token_authentication.py
rename to auth/token_authentication.py
index 4e81c42b0..b33bbbbd0 100644
--- a/gooey_token_authentication1/token_authentication.py
+++ b/auth/token_authentication.py
@@ -1,8 +1,10 @@
+import threading
+
from fastapi import Header
from fastapi.exceptions import HTTPException
from app_users.models import AppUser
-from auth_backend import _forced_auth_user
+from auth.auth_backend import authlocal
from daras_ai_v2 import db
from daras_ai_v2.crypto import PBKDF2PasswordHasher
@@ -15,9 +17,8 @@ def api_auth_header(
description=f"{auth_keyword} $GOOEY_API_KEY",
),
) -> AppUser:
- if _forced_auth_user:
- return _forced_auth_user[0]
-
+ if authlocal:
+ return authlocal[0]
return authenticate(authorization)
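
For reference, the header this dependency parses is `{auth_keyword} $GOOEY_API_KEY`, per the `Header(description=...)` above. A client-side sketch (the keyword value and endpoint path are assumptions, not shown in this diff):

```python
import requests

# Assumption: auth_keyword resolves to "Bearer"; check
# auth/token_authentication.py for the actual value.
r = requests.post(
    "https://api.gooey.ai/v2/bulk-runner/",  # illustrative endpoint
    headers={"Authorization": "Bearer sk-example-api-key"},
    json={},
)
```
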
diff --git a/bots/models.py b/bots/models.py
index b4eb715c2..0f0ddce50 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -68,6 +68,7 @@ class Workflow(models.IntegerChoices):
RELATED_QNA_MAKER = (27, "Related QnA Maker")
RELATED_QNA_MAKER_DOC = (28, "Related QnA Maker Doc")
EMBEDDINGS = (29, "Embeddings")
+ BULK_RUNNER = (30, "Bulk Runner")
@property
def short_slug(self):
@@ -239,7 +240,7 @@ def submit_api_call(
kwds=dict(
page_cls=Workflow(self.workflow).page_cls,
query_params=dict(
- example_id=self.example_id, run_id=self.id, uid=self.uid
+ example_id=self.example_id, run_id=self.run_id, uid=self.uid
),
user=current_user,
request_body=request_body,
diff --git a/bots/tests.py b/bots/tests.py
index 0bc887389..1fb66e16b 100644
--- a/bots/tests.py
+++ b/bots/tests.py
@@ -14,7 +14,7 @@
CHATML_ROLE_ASSISSTANT = "assistant"
-def test_add_balance_direct():
+def test_add_balance_direct(transactional_db):
pk = AppUser.objects.create(balance=0, is_anonymous=False).pk
amounts = [[random.randint(-100, 10_000) for _ in range(100)] for _ in range(5)]
@@ -28,7 +28,7 @@ def worker(amts):
assert AppUser.objects.get(pk=pk).balance == sum(map(sum, amounts))
-def test_create_bot_integration_conversation_message():
+def test_create_bot_integration_conversation_message(transactional_db):
# Create a new BotIntegration with WhatsApp as the platform
bot_integration = BotIntegration.objects.create(
name="My Bot Integration",
diff --git a/conftest.py b/conftest.py
index 7bea1a865..ff61671aa 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,7 +1,17 @@
+import typing
+from functools import wraps
+from threading import Thread
+from unittest.mock import patch
+
import pytest
+from pytest_subtests import subtests
+from auth import auth_backend
+from celeryapp import app
+from daras_ai_v2.base import BasePage
-@pytest.fixture
+
+@pytest.fixture(scope="session")
def django_db_setup(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
from django.core.management import call_command
@@ -9,12 +19,131 @@ def django_db_setup(django_db_setup, django_db_blocker):
call_command("loaddata", "fixture.json")
-@pytest.fixture(
- # add this fixture to all tests
- autouse=True
-)
-def enable_db_access_for_all_tests(
- # enable transactional db
- transactional_db,
+@pytest.fixture
+def force_authentication():
+ with auth_backend.force_authentication() as user:
+ yield user
+
+
+app.conf.task_always_eager = True
+
+
+@pytest.fixture
+def mock_gui_runner():
+ with patch("celeryapp.tasks.gui_runner", _mock_gui_runner):
+ yield
+
+
+@app.task
+def _mock_gui_runner(
+ *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
- pass
+ sr = page_cls.run_doc_sr(run_id, uid)
+ sr.set(sr.parent.to_dict())
+ sr.save()
+
+
+@pytest.fixture
+def threadpool_subtest(subtests, max_workers: int = 8):
+ ts = []
+
+ def submit(fn, *args, **kwargs):
+ msg = "--".join(map(str, args))
+
+ @wraps(fn)
+ def runner(*args, **kwargs):
+ with subtests.test(msg=msg):
+ return fn(*args, **kwargs)
+
+ ts.append(Thread(target=runner, args=args, kwargs=kwargs))
+
+ yield submit
+
+ for i in range(0, len(ts), max_workers):
+ s = slice(i, i + max_workers)
+ for t in ts[s]:
+ t.start()
+ for t in ts[s]:
+ t.join()
+
+
+# class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker):
+# class _dj_db_wrapper:
+# def ensure_connection(self):
+# pass
+#
+#
+# @pytest.mark.tryfirst
+# def pytest_sessionstart(session):
+# # make the session threadsafe
+# _pytest.runner.SetupState = ThreadLocalSetupState
+#
+# # ensure that the fixtures (specifically finalizers) are threadsafe
+# _pytest.fixtures.FixtureDef = ThreadLocalFixtureDef
+#
+# django.test.utils._TestState = threading.local()
+#
+# # make the environment threadsafe
+# os.environ = ThreadLocalEnviron(os.environ)
+#
+#
+# @pytest.hookimpl
+# def pytest_runtestloop(session: _pytest.main.Session):
+# if session.testsfailed and not session.config.option.continue_on_collection_errors:
+# raise session.Interrupted(
+# "%d error%s during collection"
+# % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+# )
+#
+# if session.config.option.collectonly:
+# return True
+#
+# num_threads = 10
+#
+# threadpool_items = [
+# item
+# for item in session.items
+# if any("run_in_threadpool" in marker.name for marker in item.iter_markers())
+# ]
+#
+# manager = pytest_django.plugin._blocking_manager
+# pytest_django.plugin._blocking_manager = DummyDatabaseBlocker()
+# try:
+# with manager.unblock():
+# for j in range(0, len(threadpool_items), num_threads):
+# s = slice(j, j + num_threads)
+# futs = []
+# for i, item in enumerate(threadpool_items[s]):
+# nextitem = (
+# threadpool_items[i + 1]
+# if i + 1 < len(threadpool_items)
+# else None
+# )
+# p = threading.Thread(
+# target=item.config.hook.pytest_runtest_protocol,
+# kwargs=dict(item=item, nextitem=nextitem),
+# )
+# p.start()
+# futs.append(p)
+# for p in futs:
+# p.join()
+# if session.shouldfail:
+# raise session.Failed(session.shouldfail)
+# if session.shouldstop:
+# raise session.Interrupted(session.shouldstop)
+# finally:
+# pytest_django.plugin._blocking_manager = manager
+#
+# session.items = [
+# item
+# for item in session.items
+# if not any("run_in_threadpool" in marker.name for marker in item.iter_markers())
+# ]
+# for i, item in enumerate(session.items):
+# nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
+# item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
+# if session.shouldfail:
+# raise session.Failed(session.shouldfail)
+# if session.shouldstop:
+# raise session.Interrupted(session.shouldstop)
+# return True
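
A usage sketch for the new `threadpool_subtest` fixture: each call registers one unit of work, and only after the test body returns do the threads run, in batches of `max_workers` (default 8), each reported as its own subtest. This example test is illustrative, not part of the diff:

```python
def _check_square(n: int):
    assert n * n >= n


def test_squares(threadpool_subtest):
    # Registers 100 thunks; they execute in batches of 8 threads after this
    # function body returns, each wrapped in subtests.test(msg=str(n)).
    for n in range(100):
        threadpool_subtest(_check_square, n)
```
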
diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py
index 5b67d0445..96bd7734c 100644
--- a/daras_ai/image_input.py
+++ b/daras_ai/image_input.py
@@ -10,6 +10,9 @@
from daras_ai_v2 import settings
+if False:  # import for type checkers only; never executed at runtime
+ from firebase_admin import storage
+
def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
img_cv2 = bytes_to_cv2_img(img_bytes)
@@ -54,15 +57,10 @@ def upload_file_from_bytes(
data: bytes,
content_type: str = None,
) -> str:
- from firebase_admin import storage
-
if not content_type:
content_type = mimetypes.guess_type(filename)[0]
content_type = content_type or "application/octet-stream"
-
- filename = safe_filename(filename)
- bucket = storage.bucket(settings.GS_BUCKET_NAME)
- blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
+ blob = storage_blob_for(filename)
blob.upload_from_string(data, content_type=content_type)
return blob.public_url
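
The `storage_blob_for` helper that replaces the inline logic is not shown in these hunks; presumably it factors out the deleted lines, roughly as below. A rough sketch reconstructed from the removed code, assuming the helper lives alongside `safe_filename` in `daras_ai/image_input.py`; the exact signature is an assumption:

```python
import uuid

from daras_ai_v2 import settings


def storage_blob_for(filename: str):
    # Deferred import, matching the old inline code's pattern.
    from firebase_admin import storage

    # Same steps the removed code performed: sanitize the name, then
    # allocate a uniquely-prefixed blob in the configured media bucket.
    filename = safe_filename(filename)
    bucket = storage.bucket(settings.GS_BUCKET_NAME)
    return bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
```
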
diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py
index 5befa1451..85047e770 100644
--- a/daras_ai_v2/all_pages.py
+++ b/daras_ai_v2/all_pages.py
@@ -3,6 +3,7 @@
from bots.models import Workflow
from daras_ai_v2.base import BasePage
+from recipes.BulkRunner import BulkRunnerPage
from recipes.ChyronPlant import ChyronPlantPage
from recipes.CompareLLM import CompareLLMPage
from recipes.CompareText2Img import CompareText2ImgPage
@@ -71,6 +72,7 @@
ImageSegmentationPage,
CompareUpscalerPage,
DocExtractPage,
+ BulkRunnerPage,
]
# exposed as API
diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py
index 7f15b06f4..da9a25a5e 100644
--- a/daras_ai_v2/api_examples_widget.py
+++ b/daras_ai_v2/api_examples_widget.py
@@ -6,7 +6,7 @@
from furl import furl
from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url
-from gooey_token_authentication1.token_authentication import auth_keyword
+from auth.token_authentication import auth_keyword
def get_filenames(request_body):
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index c0d594dd2..9744110c3 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -365,30 +365,35 @@ def get_sr_from_query_params_dict(self, query_params) -> SavedRun:
example_id, run_id, uid = extract_query_params(query_params)
return self.get_sr_from_query_params(example_id, run_id, uid)
- def get_sr_from_query_params(self, example_id, run_id, uid) -> SavedRun:
+ @classmethod
+ def get_sr_from_query_params(
+ cls, example_id: str, run_id: str, uid: str
+ ) -> SavedRun:
try:
if run_id and uid:
- sr = self.run_doc_sr(run_id, uid)
+ sr = cls.run_doc_sr(run_id, uid)
elif example_id:
- sr = self.example_doc_sr(example_id)
+ sr = cls.example_doc_sr(example_id)
else:
- sr = self.recipe_doc_sr()
+ sr = cls.recipe_doc_sr()
return sr
except SavedRun.DoesNotExist:
raise HTTPException(status_code=404)
- def recipe_doc_sr(self) -> SavedRun:
+ @classmethod
+ def recipe_doc_sr(cls) -> SavedRun:
return SavedRun.objects.get_or_create(
- workflow=self.workflow,
+ workflow=cls.workflow,
run_id__isnull=True,
uid__isnull=True,
example_id__isnull=True,
)[0]
+ @classmethod
def run_doc_sr(
- self, run_id: str, uid: str, create: bool = False, parent: SavedRun = None
+ cls, run_id: str, uid: str, create: bool = False, parent: SavedRun = None
) -> SavedRun:
- config = dict(workflow=self.workflow, uid=uid, run_id=run_id)
+ config = dict(workflow=cls.workflow, uid=uid, run_id=run_id)
if create:
return SavedRun.objects.get_or_create(
**config, defaults=dict(parent=parent)
@@ -396,8 +401,9 @@ def run_doc_sr(
else:
return SavedRun.objects.get(**config)
- def example_doc_sr(self, example_id: str, create: bool = False) -> SavedRun:
- config = dict(workflow=self.workflow, example_id=example_id)
+ @classmethod
+ def example_doc_sr(cls, example_id: str, create: bool = False) -> SavedRun:
+ config = dict(workflow=cls.workflow, example_id=example_id)
if create:
return SavedRun.objects.get_or_create(**config)[0]
else:
@@ -851,7 +857,7 @@ def _render(sr: SavedRun):
workflow=self.workflow,
hidden=False,
example_id__isnull=False,
- ).exclude()[:50]
+ )[:50]
grid_layout(3, example_runs, _render)
diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py
index 31865022b..5c3c2de59 100644
--- a/daras_ai_v2/doc_search_settings_widgets.py
+++ b/daras_ai_v2/doc_search_settings_widgets.py
@@ -1,3 +1,5 @@
+import typing
+
import gooey_ui as st
from daras_ai_v2 import settings
@@ -12,8 +14,18 @@ def is_user_uploaded_url(url: str) -> bool:
def document_uploader(
label: str,
- key="documents",
- accept=(".pdf", ".txt", ".docx", ".md", ".html", ".wav", ".ogg", ".mp3", ".aac"),
+ key: str = "documents",
+ accept: typing.Iterable[str] = (
+ ".pdf",
+ ".txt",
+ ".docx",
+ ".md",
+ ".html",
+ ".wav",
+ ".ogg",
+ ".mp3",
+ ".aac",
+ ),
):
st.write(label, className="gui-input")
documents = st.session_state.get(key) or []
@@ -45,6 +57,7 @@ def document_uploader(
accept=accept,
accept_multiple_files=True,
)
+ return st.session_state.get(key, [])
def doc_search_settings(
diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 657d2898c..1baf64b47 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -3,12 +3,13 @@
from enum import Enum
import jinja2
+from typing_extensions import TypedDict
import gooey_ui
from daras_ai_v2.scrollable_html_widget import scrollable_html
-class SearchReference(typing.TypedDict):
+class SearchReference(TypedDict):
url: str
title: str
snippet: str
@@ -19,17 +20,23 @@ class CitationStyles(Enum):
number = "Numbers ( [1] [2] [3] ..)"
title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)"
url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)"
+ symbol = "Symbols ( [*] [†] [‡] ..)"
markdown = "Markdown ( [Source 1](https://source1.com) [Source 2](https://source2.com) [Source 3](https://source3.com) ..)"
html = "HTML ( Source 1 Source 2 Source 3 ..)"
slack_mrkdwn = "Slack mrkdwn ( ..)"
plaintext = "Plain Text / WhatsApp ( [Source 1 https://source1.com] [Source 2 https://source2.com] [Source 3 https://source3.com] ..)"
- number_markdown = " Markdown Numbers + Footnotes"
+ number_markdown = "Markdown Numbers + Footnotes"
number_html = "HTML Numbers + Footnotes"
number_slack_mrkdwn = "Slack mrkdown Numbers + Footnotes"
number_plaintext = "Plain Text / WhatsApp Numbers + Footnotes"
+ symbol_markdown = "Markdown Symbols + Footnotes"
+ symbol_html = "HTML Symbols + Footnotes"
+ symbol_slack_mrkdwn = "Slack mrkdown Symbols + Footnotes"
+ symbol_plaintext = "Plain Text / WhatsApp Symbols + Footnotes"
+
def remove_quotes(snippet: str) -> str:
return re.sub(r"[\"\']+", r'"', snippet).strip()
@@ -63,36 +70,65 @@ def apply_response_template(
match citation_style:
case CitationStyles.number | CitationStyles.number_plaintext:
cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
- case CitationStyles.number_html:
+ case CitationStyles.title:
+ cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+ case CitationStyles.url:
+ cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+ case CitationStyles.symbol | CitationStyles.symbol_plaintext:
cites = " ".join(
- html_link(f"[{ref_num}]", ref["url"])
+ f"[{generate_footnote_symbol(ref_num - 1)}]"
+ for ref_num in ref_map.keys()
+ )
+
+ case CitationStyles.markdown:
+ cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+ case CitationStyles.html:
+ cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+ case CitationStyles.slack_mrkdwn:
+ cites = " ".join(
+ ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+ )
+ case CitationStyles.plaintext:
+ cites = " ".join(
+ f'[{ref["title"]} {ref["url"]}]'
for ref_num, ref in ref_map.items()
)
+
case CitationStyles.number_markdown:
cites = " ".join(
markdown_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
+ case CitationStyles.number_html:
+ cites = " ".join(
+ html_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
+ )
case CitationStyles.number_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
- case CitationStyles.title:
- cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
- case CitationStyles.url:
- cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
- case CitationStyles.markdown:
- cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
- case CitationStyles.html:
- cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
- case CitationStyles.slack_mrkdwn:
+
+ case CitationStyles.symbol_markdown:
cites = " ".join(
- ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+ markdown_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
- case CitationStyles.plaintext:
+ case CitationStyles.symbol_html:
cites = " ".join(
- f'[{ref["title"]} {ref["url"]}]'
+ html_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ cites = " ".join(
+ slack_mrkdwn_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
for ref_num, ref in ref_map.items()
)
case None:
@@ -128,6 +164,31 @@ def apply_response_template(
for ref_num, ref in sorted(all_refs.items())
)
+ case CitationStyles.symbol_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_html:
+ formatted += "
"
+ formatted += "
".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+
for ref_num, ref in all_refs.items():
try:
template = ref["response_template"]
@@ -205,3 +266,9 @@ def render_output_with_refs(state, height):
for text in output_text:
html = render_text_with_refs(text, state.get("references", []))
scrollable_html(html, height=height)
+
+
+FOOTNOTE_SYMBOLS = ["*", "†", "‡", "§", "¶", "#", "♠", "♥", "♦", "♣", "✠", "☮", "☯", "✡"] # fmt: skip
+
+
+def generate_footnote_symbol(idx: int) -> str:
+ quotient, remainder = divmod(idx, len(FOOTNOTE_SYMBOLS))
+ return FOOTNOTE_SYMBOLS[remainder] * (quotient + 1)
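
To make the wrap-around behavior concrete: `generate_footnote_symbol` cycles through the 14-symbol table, doubling the symbol on the second pass, tripling on the third, and so on:

```python
>>> [generate_footnote_symbol(i) for i in (0, 1, 13, 14, 15, 28)]
['*', '†', '✡', '**', '††', '***']
```
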
diff --git a/gooey_ui/components.py b/gooey_ui/components.py
index f9b366d1c..ea3bbedda 100644
--- a/gooey_ui/components.py
+++ b/gooey_ui/components.py
@@ -241,7 +241,7 @@ def text_area(
key = md5_values(
"textarea", label, height, help, value, placeholder, label_visibility
)
- value = state.session_state.setdefault(key, value)
+ value = str(state.session_state.setdefault(key, value))
if label_visibility != "visible":
label = None
if disabled:
@@ -462,6 +462,10 @@ def json(value: typing.Any, expanded: bool = False, depth: int = 1):
).mount()
+def data_table(file_url: str):
+ return _node("data-table", fileUrl=file_url)
+
+
def table(df: "pd.DataFrame"):
state.RenderTreeNode(
name="table",
diff --git a/poetry.lock b/poetry.lock
index 32a70bdde..de91446e3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -173,24 +173,24 @@ vine = ">=5.0.0"
[[package]]
name = "anyio"
-version = "4.0.0"
+version = "3.7.1"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.7"
files = [
- {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"},
- {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"},
+ {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"},
+ {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"},
]
[package.dependencies]
-exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
+exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
idna = ">=2.8"
sniffio = ">=1.1"
[package.extras]
-doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"]
-test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
-trio = ["trio (>=0.22)"]
+doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"]
+test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
+trio = ["trio (<0.22)"]
[[package]]
name = "appdirs"
@@ -1483,11 +1483,8 @@ files = [
[package.dependencies]
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
-grpcio = [
- {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
- {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-]
-grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}
+grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
+grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
@@ -1604,8 +1601,8 @@ files = [
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-cloud-core = ">=1.4.1,<3.0.0dev"
proto-plus = [
- {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -1623,8 +1620,8 @@ files = [
[package.dependencies]
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
proto-plus = [
- {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -1663,8 +1660,8 @@ files = [
[package.dependencies]
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
proto-plus = [
- {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -3395,12 +3392,9 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
- {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
- {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
- {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
- {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
+ {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
+ {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
]
[[package]]
@@ -3458,8 +3452,8 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -4238,6 +4232,21 @@ pytest = ">=5.4.0"
docs = ["sphinx", "sphinx-rtd-theme"]
testing = ["Django", "django-configurations (>=2.0)"]
+[[package]]
+name = "pytest-subtests"
+version = "0.11.0"
+description = "unittest subTest() support and subtests fixture"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-subtests-0.11.0.tar.gz", hash = "sha256:51865c88457545f51fb72011942f0a3c6901ee9e24cbfb6d1b9dc1348bafbe37"},
+ {file = "pytest_subtests-0.11.0-py3-none-any.whl", hash = "sha256:453389984952eec85ab0ce0c4f026337153df79587048271c7fd0f49119c07e4"},
+]
+
+[package.dependencies]
+attrs = ">=19.2.0"
+pytest = ">=7.0"
+
[[package]]
name = "pytest-xdist"
version = "3.3.1"
@@ -5204,7 +5213,7 @@ files = [
]
[package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"}
+greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
[package.extras]
aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]
@@ -6129,4 +6138,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "e1aa198ba112e95195815327669b1e7687428ab0510c5485f057442a207b982e"
+content-hash = "8cb6f5a826bc1bbd06c65a871e027cf967178cd2e8e6c33c262b80a43a772cfe"
diff --git a/pyproject.toml b/pyproject.toml
index 3ea6240be..6c697412c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,12 +18,12 @@ pandas = "^2.0.1"
google-cloud-firestore = "^2.7.0"
replicate = "^0.4.0"
fastapi = "^0.85.0"
-uvicorn = {extras = ["standard"], version = "^0.18.3"}
+uvicorn = { extras = ["standard"], version = "^0.18.3" }
firebase-admin = "^6.0.0"
# mediapipe for M1 macs
-mediapipe-silicon = {version = "^0.8.11", markers = "platform_machine == 'arm64'", platform = "darwin"}
+mediapipe-silicon = { version = "^0.8.11", markers = "platform_machine == 'arm64'", platform = "darwin" }
# mediapipe for others
-mediapipe = {version = "^0.8.11", markers = "platform_machine != 'arm64'"}
+mediapipe = { version = "^0.8.11", markers = "platform_machine != 'arm64'" }
furl = "^2.1.3"
itsdangerous = "^2.1.2"
pytest = "^7.2.0"
@@ -53,7 +53,7 @@ llama-index = "^0.5.27"
nltk = "^3.8.1"
Jinja2 = "^3.1.2"
Django = "^4.2"
-django-phonenumber-field = {extras = ["phonenumberslite"], version = "^7.0.2"}
+django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" }
gunicorn = "^20.1.0"
psycopg2-binary = "^2.9.6"
whitenoise = "^6.4.0"
@@ -75,6 +75,8 @@ tabulate = "^0.9.0"
deepgram-sdk = "^2.11.0"
scipy = "^1.11.2"
rank-bm25 = "^0.2.2"
+pytest-subtests = "^0.11.0"
+anyio = "^3.4.0"
[tool.poetry.group.dev.dependencies]
watchdog = "^2.1.9"
diff --git a/pytest.ini b/pytest.ini
index 55040e469..4a9863462 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,4 @@
[pytest]
-addopts = --tb=native -vv -n 16 --disable-warnings
+addopts = --tb=native --disable-warnings
DJANGO_SETTINGS_MODULE = daras_ai_v2.settings
-python_files = tests.py test_*.py *_tests.py
\ No newline at end of file
+python_files = tests.py test_*.py *_tests.py
diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py
new file mode 100644
index 000000000..5d5f9b554
--- /dev/null
+++ b/recipes/BulkRunner.py
@@ -0,0 +1,357 @@
+import io
+import typing
+
+import pandas as pd
+import requests
+from fastapi import HTTPException
+from furl import furl
+from pydantic import BaseModel
+
+import gooey_ui as st
+from bots.models import Workflow
+from daras_ai.image_input import upload_file_from_bytes
+from daras_ai_v2.base import BasePage
+from daras_ai_v2.doc_search_settings_widgets import document_uploader
+from daras_ai_v2.functional import map_parallel
+from daras_ai_v2.query_params_util import extract_query_params
+from recipes.DocSearch import render_documents
+
+CACHED_COLUMNS = "__cached_columns"
+
+
+class BulkRunnerPage(BasePage):
+ title = "Bulk Runner"
+ workflow = Workflow.BULK_RUNNER
+ slug_versions = ["bulk-runner", "bulk"]
+
+ class RequestModel(BaseModel):
+ documents: list[str]
+ run_urls: list[str]
+
+ input_columns: dict[str, str]
+ output_columns: dict[str, str]
+
+ class ResponseModel(BaseModel):
+ output_documents: list[str]
+
+ def fields_to_save(self) -> list[str]:
+ return super().fields_to_save() + [CACHED_COLUMNS]
+
+ def render_form_v2(self):
+ from daras_ai_v2.all_pages import page_slug_map, normalize_slug
+
+ run_urls = st.session_state.get("run_urls", "")
+ st.session_state.setdefault("__run_urls", "\n".join(run_urls))
+ run_urls = (
+ st.text_area("##### Run URL(s)", key="__run_urls").strip().splitlines()
+ )
+ st.session_state["run_urls"] = run_urls
+
+ files = document_uploader(
+ "##### Upload a File",
+ accept=(".csv", ".xlsx", ".xls", ".json", ".tsv", ".xml"),
+ )
+
+ if files:
+ st.session_state[CACHED_COLUMNS] = list(
+ {
+ col: None
+ for df in map_parallel(_read_df, files)
+ for col in df.columns
+ if not col.startswith("Unnamed:")
+ }
+ )
+ else:
+ st.session_state.pop(CACHED_COLUMNS, None)
+
+ required_input_fields = {}
+ optional_input_fields = {}
+ output_fields = {}
+
+ for url in run_urls:
+ f = furl(url)
+ slug = f.path.segments[0]
+ try:
+ page_cls = page_slug_map[normalize_slug(slug)]
+ except KeyError as e:
+ st.error(repr(e))
+ continue
+
+ example_id, run_id, uid = extract_query_params(f.query.params)
+ try:
+ sr = page_cls.get_sr_from_query_params(example_id, run_id, uid)
+ except HTTPException as e:
+ st.error(repr(e))
+ continue
+
+ schema = page_cls.RequestModel.schema(ref_template="{model}")
+ for field, model_field in page_cls.RequestModel.__fields__.items():
+ if model_field.required:
+ input_fields = required_input_fields
+ else:
+ input_fields = optional_input_fields
+ field_props = schema["properties"][field]
+ title = field_props["title"]
+ keys = None
+ if is_arr(field_props):
+ try:
+ ref = field_props["items"]["$ref"]
+ props = schema["definitions"][ref]["properties"]
+ keys = {k: prop["title"] for k, prop in props.items()}
+ except KeyError:
+ try:
+ keys = {k: k for k in sr.state[field][0].keys()}
+ except (KeyError, IndexError, AttributeError):
+ pass
+ elif field_props.get("type") == "object":
+ try:
+ keys = {k: k for k in sr.state[field].keys()}
+ except (KeyError, AttributeError):
+ pass
+ if keys:
+ for k, ktitle in keys.items():
+ input_fields[f"{field}.{k}"] = f"{title}.{ktitle}"
+ else:
+ input_fields[field] = title
+
+ schema = page_cls.ResponseModel.schema()
+ output_fields |= {
+ field: schema["properties"][field]["title"]
+ for field, model_field in page_cls.ResponseModel.__fields__.items()
+ }
+
+ columns = st.session_state.get(CACHED_COLUMNS, [])
+ if not columns:
+ return
+
+ for file in files:
+ st.data_table(file)
+
+ if not (required_input_fields or optional_input_fields):
+ return
+
+ col1, col2 = st.columns(2)
+
+ with col1:
+ st.write("##### Input Columns")
+
+ input_columns_old = st.session_state.pop("input_columns", {})
+ input_columns_new = st.session_state.setdefault("input_columns", {})
+
+ column_options = [None, *columns]
+ for fields in (required_input_fields, optional_input_fields):
+ for field, title in fields.items():
+ col = st.selectbox(
+ label="`" + title + "`",
+ options=column_options,
+ key="--input-mapping:" + field,
+ default_value=input_columns_old.get(field),
+ )
+ if col:
+ input_columns_new[field] = col
+ st.write("---")
+
+ with col2:
+ st.write("##### Output Columns")
+
+ output_columns_old = st.session_state.pop("output_columns", {})
+ output_columns_new = st.session_state.setdefault("output_columns", {})
+
+ prev_fields = st.session_state.get("--prev-output-fields")
+ fields = {**output_fields, "error_msg": "Error Msg", "run_url": "Run URL"}
+ did_change = prev_fields is not None and prev_fields != fields
+ st.session_state["--prev-output-fields"] = fields
+ for field, title in fields.items():
+ col = st.text_input(
+ label="`" + title + "`",
+ key="--output-mapping:" + field,
+ value=output_columns_old.get(field, title if did_change else None),
+ )
+ if col:
+ output_columns_new[field] = col
+
+ def render_example(self, state: dict):
+ render_documents(state)
+
+ def render_output(self):
+ files = st.session_state.get("output_documents", [])
+ for file in files:
+ st.write(file)
+ st.data_table(file)
+
+ def run_v2(
+ self,
+ request: "BulkRunnerPage.RequestModel",
+ response: "BulkRunnerPage.ResponseModel",
+ ) -> typing.Iterator[str | None]:
+ response.output_documents = []
+
+ for doc_ix, doc in enumerate(request.documents):
+ df = _read_df(doc)
+ in_recs = df.to_dict(orient="records")
+ out_recs = []
+
+ f = upload_file_from_bytes(
+ filename=f"bulk-runner-{doc_ix}-0-0.csv",
+ data=df.to_csv(index=False).encode(),
+ content_type="text/csv",
+ )
+ response.output_documents.append(f)
+
+ df_slices = list(slice_request_df(df, request))
+ for slice_ix, (df_ix, arr_len) in enumerate(df_slices):
+ rec_ix = len(out_recs)
+ out_recs.extend(in_recs[df_ix : df_ix + arr_len])
+
+ for url_ix, f, request_body, page_cls in build_requests_for_df(
+ df, request, df_ix, arr_len
+ ):
+ progress = round(
+ (slice_ix + url_ix)
+ / (len(df_slices) + len(request.run_urls))
+ * 100
+ )
+ yield f"{progress}%"
+
+ example_id, run_id, uid = extract_query_params(f.query.params)
+ sr = page_cls.get_sr_from_query_params(example_id, run_id, uid)
+
+ result, sr = sr.submit_api_call(
+ current_user=self.request.user, request_body=request_body
+ )
+ result.get(disable_sync_subtasks=False)
+ sr.refresh_from_db()
+ state = sr.to_dict()
+ state["run_url"] = sr.get_app_url()
+ state["error_msg"] = sr.error_msg
+
+ for field, col in request.output_columns.items():
+ if len(request.run_urls) > 1:
+ col = f"({url_ix + 1}) {col}"
+ out_val = state.get(field)
+ if isinstance(out_val, list):
+ for arr_ix, item in enumerate(out_val):
+ if len(out_recs) <= rec_ix + arr_ix:
+ out_recs.append({})
+ if isinstance(item, dict):
+ for key, val in item.items():
+ out_recs[rec_ix + arr_ix][f"{col}.{key}"] = str(
+ val
+ )
+ else:
+ out_recs[rec_ix + arr_ix][col] = str(item)
+ elif isinstance(out_val, dict):
+ for key, val in out_val.items():
+ if isinstance(val, list):
+ for arr_ix, item in enumerate(val):
+ if len(out_recs) <= rec_ix + arr_ix:
+ out_recs.append({})
+ out_recs[rec_ix + arr_ix][f"{col}.{key}"] = str(
+ item
+ )
+ else:
+ out_recs[rec_ix][f"{col}.{key}"] = str(val)
+ else:
+ out_recs[rec_ix][col] = str(out_val)
+
+ out_df = pd.DataFrame.from_records(out_recs)
+ f = upload_file_from_bytes(
+ filename=f"bulk-runner-{doc_ix}-{url_ix}-{df_ix}.csv",
+ data=out_df.to_csv(index=False).encode(),
+ content_type="text/csv",
+ )
+ response.output_documents[doc_ix] = f
+
+
+def build_requests_for_df(df, request, df_ix, arr_len):
+ from daras_ai_v2.all_pages import page_slug_map, normalize_slug
+
+ for url_ix, url in enumerate(request.run_urls):
+ f = furl(url)
+ slug = f.path.segments[0]
+ page_cls = page_slug_map[normalize_slug(slug)]
+ schema = page_cls.RequestModel.schema()
+ properties = schema["properties"]
+
+ request_body = {}
+ for field, col in request.input_columns.items():
+ parts = field.split(".")
+ field_props = properties.get(parts[0]) or properties.get(field)
+ if is_arr(field_props):
+ arr = request_body.setdefault(parts[0], [])
+ for arr_ix in range(arr_len):
+ value = df.at[df_ix + arr_ix, col]
+ if len(parts) > 1:
+ if len(arr) <= arr_ix:
+ arr.append({})
+ arr[arr_ix][parts[1]] = value
+ else:
+ if len(arr) <= arr_ix:
+ arr.append(None)
+ arr[arr_ix] = value
+ elif len(parts) > 1 and field_props.get("type") == "object":
+ obj = request_body.setdefault(parts[0], {})
+ obj[parts[1]] = df.at[df_ix, col]
+ else:
+ request_body[field] = df.at[df_ix, col]
+ # for validation
+ request_body = page_cls.RequestModel.parse_obj(request_body).dict()
+
+ yield url_ix, f, request_body, page_cls
+
+
+def slice_request_df(df, request):
+ from daras_ai_v2.all_pages import page_slug_map, normalize_slug
+
+ non_array_cols = set()
+ for url_ix, url in enumerate(request.run_urls):
+ f = furl(url)
+ slug = f.path.segments[0]
+ page_cls = page_slug_map[normalize_slug(slug)]
+ schema = page_cls.RequestModel.schema()
+ properties = schema["properties"]
+
+ for field, col in request.input_columns.items():
+ if not is_arr(properties.get(field.split(".")[0], {})):
+ non_array_cols.add(col)
+ non_array_df = df[list(non_array_cols)]
+
+ df_ix = 0
+ while df_ix < len(df):
+ arr_len = 1
+ while df_ix + arr_len < len(df):
+ if not non_array_df.iloc[df_ix + arr_len].isnull().all():
+ break
+ arr_len += 1
+ yield df_ix, arr_len
+ df_ix += arr_len
+
+
+def is_arr(field_props: dict) -> bool:
+ try:
+ return field_props["type"] == "array"
+ except KeyError:
+ for props in field_props.get("anyOf", []):
+ if props["type"] == "array":
+ return True
+ return False
+
+
+def _read_df(f: str) -> "pd.DataFrame":
+ import pandas as pd
+
+ r = requests.get(f)
+ r.raise_for_status()
+ if f.endswith(".csv"):
+ df = pd.read_csv(io.StringIO(r.text))
+ elif f.endswith(".xlsx") or f.endswith(".xls"):
+ df = pd.read_excel(io.BytesIO(r.content))
+ elif f.endswith(".json"):
+ df = pd.read_json(io.StringIO(r.text))
+ elif f.endswith(".tsv"):
+ df = pd.read_csv(io.StringIO(r.text), sep="\t")
+ elif f.endswith(".xml"):
+ df = pd.read_xml(io.StringIO(r.text))
+ else:
+ raise ValueError(f"Unsupported file type: {f}")
+ return df.dropna(how="all", axis=1).dropna(how="all", axis=0)
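
To clarify the row-grouping behavior: `slice_request_df` starts a new slice at each row that has a value in some column mapped to a non-array field, and folds the following rows (empty in all such columns) into the previous row's array fields. An illustrative sketch with hypothetical column mappings:

```python
import pandas as pd

# Hypothetical mapping: "prompt" -> scalar field, "messages.content" -> array
# field. Only "prompt" ends up in non_array_cols, so rows with an empty
# prompt extend the previous request instead of starting a new one.
df = pd.DataFrame(
    {
        "prompt": ["hi", None, None, "yo"],
        "messages.content": ["a", "b", "c", "d"],
    }
)
# slice_request_df(df, request) would then yield (0, 3) and (3, 1): two
# requests, the first with messages=[{"content": "a"}, {"content": "b"},
# {"content": "c"}] assembled by build_requests_for_df.
```
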
diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py
index a30396907..f523ba14e 100644
--- a/recipes/DeforumSD.py
+++ b/recipes/DeforumSD.py
@@ -3,6 +3,7 @@
from django.db.models import TextChoices
from pydantic import BaseModel
+from typing_extensions import TypedDict
import gooey_ui as st
from app_users.models import AppUser
@@ -21,7 +22,7 @@ class AnimationModels(TextChoices):
epicdream = ("epicdream.safetensors", "epiCDream (epinikion)")
-class _AnimationPrompt(typing.TypedDict):
+class _AnimationPrompt(TypedDict):
frame: str
prompt: str
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 4c8895901..d5f405ed6 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -244,6 +244,7 @@ class ResponseModel(BaseModel):
final_prompt: str
raw_input_text: str | None
raw_output_text: list[str] | None
+ raw_tts_text: list[str] | None
output_text: list[str]
# tts
@@ -734,11 +735,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
source_language="en",
target_language=request.user_language,
)
-
- tts_text = [
- "".join(snippet for snippet, _ in parse_refs(text, references))
- for text in output_text
- ]
+ state["raw_tts_text"] = [
+ "".join(snippet for snippet, _ in parse_refs(text, references))
+ for text in output_text
+ ]
if references:
citation_style = (
@@ -754,7 +754,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
if not request.tts_provider:
return
tts_state = dict(state)
- for text in tts_text:
+ for text in state.get("raw_tts_text", state["raw_output_text"]):
tts_state["text_prompt"] = text
yield from TextToSpeechPage().run(tts_state)
state["output_audio"].append(tts_state["audio_url"])
diff --git a/routers/api.py b/routers/api.py
index 47b7989a1..62514cc14 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -29,7 +29,7 @@
BasePage,
StateKeys,
)
-from gooey_token_authentication1.token_authentication import api_auth_header
+from auth.token_authentication import api_auth_header
app = APIRouter()
diff --git a/routers/root.py b/routers/root.py
index 386cc9813..42627572e 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -21,9 +21,6 @@
import gooey_ui as st
from app_users.models import AppUser
-from auth_backend import (
- FIREBASE_SESSION_COOKIE,
-)
from daras_ai.image_input import upload_file_from_bytes, safe_filename
from daras_ai_v2 import settings
from daras_ai_v2.all_pages import all_api_pages, normalize_slug, page_slug_map
@@ -32,6 +29,7 @@
RedirectException,
)
from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts
+from daras_ai_v2.db import FIREBASE_SESSION_COOKIE
from daras_ai_v2.meta_content import build_meta_tags
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.settings import templates
diff --git a/routers/slack.py b/routers/slack.py
index 75ed8465e..20e19fec6 100644
--- a/routers/slack.py
+++ b/routers/slack.py
@@ -14,7 +14,6 @@
from bots.tasks import create_personal_channels_for_all_members
from daras_ai_v2 import settings
from daras_ai_v2.bots import _on_msg, request_json, request_urlencoded_body
-from daras_ai_v2.search_ref import parse_refs
from daras_ai_v2.slack_bot import (
SlackBot,
invite_bot_account_to_channel,
@@ -23,8 +22,6 @@
fetch_user_info,
parse_slack_response,
)
-from gooey_token_authentication1.token_authentication import auth_keyword
-from recipes.VideoBots import VideoBotsPage
router = APIRouter()
@@ -368,7 +365,6 @@ def slack_auth_header(
@router.get("/__/slack/get-response-for-msg/{msg_id}/")
def slack_get_response_for_msg_id(
msg_id: str,
- remove_refs: bool = True,
slack_user: dict = Depends(slack_auth_header),
):
try:
@@ -384,11 +380,13 @@ def slack_get_response_for_msg_id(
):
raise HTTPException(403, "Not authorized")
- output_text = response_msg.saved_run.state.get("output_text")
+ state = response_msg.saved_run.state
+ output_text = (
+ state.get("raw_tts_text")
+ or state.get("raw_output_text")
+ or state.get("output_text")
+ )
if not output_text:
return {"status": "no_output"}
- content = output_text[0]
- if remove_refs:
- content = "".join(snippet for snippet, _ in parse_refs(content, []))
- return {"status": "ok", "content": content}
+ return {"status": "ok", "content": output_text[0]}
diff --git a/scripts/create_fixture.py b/scripts/create_fixture.py
index d84495221..9482cef64 100644
--- a/scripts/create_fixture.py
+++ b/scripts/create_fixture.py
@@ -1,9 +1,24 @@
-from django.core.management import call_command
+import sys
+
+from django.core import serializers
from bots.models import SavedRun
def run():
- qs = SavedRun.objects.filter(run_id__isnull=True).values_list("pk", flat=True)
- pks = ",".join(map(str, qs))
- call_command("dumpdata", "bots.SavedRun", "--pks", pks, "--output", "fixture.json")
+ qs = SavedRun.objects.filter(run_id__isnull=True)
+ with open("fixture.json", "w") as f:
+ serializers.serialize(
+ "json",
+ get_objects(qs),
+ indent=2,
+ stream=f,
+ progress_output=sys.stdout,
+ object_count=qs.count(),
+ )
+
+
+def get_objects(qs):
+ for obj in qs:
+ obj.parent = None
+ yield obj
diff --git a/server.py b/server.py
index 1e6d03e1e..078b1508b 100644
--- a/server.py
+++ b/server.py
@@ -1,3 +1,5 @@
+import logging
+
import anyio
from decouple import config
@@ -18,7 +20,7 @@
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
-from auth_backend import (
+from auth.auth_backend import (
SessionAuthBackend,
)
from daras_ai_v2 import settings
@@ -67,10 +69,14 @@ async def health():
@app.add_middleware
def request_time_middleware(app):
+ logger = logging.getLogger("uvicorn.time")
+
async def middleware(scope, receive, send):
start_time = time()
await app(scope, receive, send)
response_time = (time() - start_time) * 1000
- print(f"{scope.get('method')} {scope.get('path')} - {response_time:.3f} ms")
+ logger.info(
+ f"{scope.get('method')} {scope.get('path')} - {response_time:.3f} ms"
+ )
return middleware
diff --git a/tests/test_apis.py b/tests/test_apis.py
index d3e1d8b7d..6412e9eda 100644
--- a/tests/test_apis.py
+++ b/tests/test_apis.py
@@ -1,10 +1,10 @@
import typing
import pytest
-from fastapi.testclient import TestClient
+from starlette.testclient import TestClient
-from auth_backend import force_authentication
-from daras_ai_v2 import db
+from auth.auth_backend import force_authentication
+from bots.models import SavedRun, Workflow
from daras_ai_v2.all_pages import all_test_pages
from daras_ai_v2.base import (
BasePage,
@@ -12,21 +12,77 @@
)
from server import app
+MAX_WORKERS = 20
+
client = TestClient(app)
-@pytest.mark.parametrize("page_cls", all_test_pages)
-def test_apis_basic(page_cls: typing.Type[BasePage]):
- page = page_cls()
- state = db.get_or_create_doc(db.get_doc_ref(page.slug_versions[0])).to_dict()
+@pytest.mark.django_db
+def test_apis_sync(mock_gui_runner, force_authentication, threadpool_subtest):
+ for page_cls in all_test_pages:
+ threadpool_subtest(_test_api_sync, page_cls)
+
+
+def _test_api_sync(page_cls: typing.Type[BasePage]):
+ state = page_cls.recipe_doc_sr().state
+ r = client.post(
+ f"/v2/{page_cls.slug_versions[0]}/",
+ json=get_example_request_body(page_cls.RequestModel, state),
+ headers={"Authorization": f"Token None"},
+ allow_redirects=False,
+ )
+ assert r.status_code == 200, r.text
+
+
+@pytest.mark.django_db
+def test_apis_async(mock_gui_runner, force_authentication, threadpool_subtest):
+ for page_cls in all_test_pages:
+ threadpool_subtest(_test_api_async, page_cls)
+
+
+def _test_api_async(page_cls: typing.Type[BasePage]):
+ state = page_cls.recipe_doc_sr().state
+
+ r = client.post(
+ f"/v3/{page_cls.slug_versions[0]}/async/",
+ json=get_example_request_body(page_cls.RequestModel, state),
+ headers={"Authorization": f"Token None"},
+ allow_redirects=False,
+ )
+ assert r.status_code == 202, r.text
+
+ status_url = r.json()["status_url"]
+
+ r = client.get(
+ status_url,
+ headers={"Authorization": f"Token None"},
+ allow_redirects=False,
+ )
+ assert r.status_code == 200, r.text
+
+ data = r.json()
+ assert data.get("status") == "completed", data
+ assert data.get("output") is not None, data
+
+
+@pytest.mark.django_db
+def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest):
+ for page in all_test_pages:
+ for sr in SavedRun.objects.filter(
+ workflow=page.workflow,
+ hidden=False,
+ example_id__isnull=False,
+ ):
+ threadpool_subtest(_test_apis_examples, sr)
- with force_authentication():
- r = client.post(
- page.endpoint,
- json=get_example_request_body(page.RequestModel, state),
- headers={"Authorization": f"Token None"},
- allow_redirects=False,
- )
- print(r.content)
- assert r.status_code == 200
+def _test_apis_examples(sr: SavedRun):
+ state = sr.state
+ page_cls = Workflow(sr.workflow).page_cls
+ r = client.post(
+ f"/v2/{page_cls.slug_versions[0]}/?example_id={sr.example_id}",
+ json=get_example_request_body(page_cls.RequestModel, state),
+ headers={"Authorization": f"Token None"},
+ allow_redirects=False,
+ )
+ assert r.status_code == 200, r.text
diff --git a/tests/test_checkout.py b/tests/test_checkout.py
index 5ea7c326b..5e822b775 100644
--- a/tests/test_checkout.py
+++ b/tests/test_checkout.py
@@ -1,7 +1,7 @@
import pytest
from fastapi.testclient import TestClient
-from auth_backend import force_authentication
+from auth.auth_backend import force_authentication
from routers.billing import available_subscriptions
from server import app
@@ -9,7 +9,7 @@
@pytest.mark.parametrize("subscription", available_subscriptions.keys())
-def test_create_checkout_session(subscription: str):
+def test_create_checkout_session(subscription: str, transactional_db):
with force_authentication():
form_data = {"lookup_key": subscription}
if subscription == "addon":
diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py
index 009fb836c..02006b92f 100644
--- a/tests/test_public_endpoints.py
+++ b/tests/test_public_endpoints.py
@@ -1,20 +1,20 @@
import pytest
-import requests
-from furl import furl
from starlette.routing import Route
from starlette.testclient import TestClient
-from auth_backend import force_authentication
-from daras_ai_v2 import settings
+from bots.models import SavedRun
from daras_ai_v2.all_pages import all_api_pages
from daras_ai_v2.tabs_widget import MenuTabs
from routers import facebook
+from routers.slack import slack_connect_redirect_shortcuts, slack_connect_redirect
from server import app
client = TestClient(app)
excluded_endpoints = [
facebook.fb_webhook_verify.__name__, # gives 403
+ slack_connect_redirect.__name__,
+ slack_connect_redirect_shortcuts.__name__,
"get_run_status", # needs query params
]
@@ -30,10 +30,10 @@
]
+@pytest.mark.django_db
@pytest.mark.parametrize("path", route_paths)
def test_all_get(path):
r = client.get(path, allow_redirects=False)
- print(r.content)
assert r.ok
@@ -41,14 +41,29 @@ def test_all_get(path):
tabs = list(MenuTabs.paths.values())
-@pytest.mark.parametrize("tab", tabs)
+@pytest.mark.django_db
@pytest.mark.parametrize("slug", page_slugs)
+@pytest.mark.parametrize("tab", tabs)
def test_page_slugs(slug, tab):
- with force_authentication():
- r = requests.post(
- str(furl(settings.API_BASE_URL) / slug / tab),
- json={},
- )
- # r = client.post(os.path.join(slug, tab), json={}, allow_redirects=True)
- print(r.content)
+ r = client.post(
+ f"/{slug}/{tab}",
+ json={},
+ allow_redirects=True,
+ )
assert r.status_code == 200
+
+
+@pytest.mark.django_db
+def test_example_slugs(subtests):
+ for page_cls in all_api_pages:
+ for tab in tabs:
+ for example_id in SavedRun.objects.filter(
+ workflow=page_cls.workflow,
+ hidden=False,
+ example_id__isnull=False,
+ ).values_list("example_id", flat=True):
+ slug = page_cls.slug_versions[0]
+ url = f"/{slug}/{tab}?example_id={example_id}"
+ with subtests.test(msg=url):
+ r = client.post(url, json={}, allow_redirects=True)
+ assert r.status_code == 200
diff --git a/tests/test_search_refs.py b/tests/test_search_refs.py
index 5374db260..00698bf88 100644
--- a/tests/test_search_refs.py
+++ b/tests/test_search_refs.py
@@ -1,4 +1,6 @@
-from daras_ai_v2.search_ref import parse_refs
+import pytest
+
+from daras_ai_v2.search_ref import parse_refs, generate_footnote_symbol
def test_ref_parser():
@@ -126,3 +128,21 @@ def test_ref_parser():
},
),
]
+
+
+def test_generate_footnote_symbol():
+ assert generate_footnote_symbol(0) == "*"
+ assert generate_footnote_symbol(1) == "†"
+ assert generate_footnote_symbol(13) == "✡"
+ assert generate_footnote_symbol(14) == "**"
+ assert generate_footnote_symbol(15) == "††"
+ assert generate_footnote_symbol(27) == "✡✡"
+ assert generate_footnote_symbol(28) == "***"
+ assert generate_footnote_symbol(29) == "†††"
+ assert generate_footnote_symbol(41) == "✡✡✡"
+ assert generate_footnote_symbol(70) == "******"
+ assert generate_footnote_symbol(71) == "††††††"
+
+ # testing with non-integer index
+ with pytest.raises(TypeError):
+ generate_footnote_symbol(1.5)
diff --git a/tests/test_slack.py b/tests/test_slack.py
index ebe72ca96..3a5802c8d 100644
--- a/tests/test_slack.py
+++ b/tests/test_slack.py
@@ -33,6 +33,7 @@ def test_slack_get_response_for_msg_id(transactional_db):
platform_msg_id="response-msg",
saved_run=SavedRun.objects.create(
state=VideoBotsPage.ResponseModel(
+ raw_tts_text=["hello, world!"],
output_text=["hello, world! [2]"],
final_prompt="",
output_audio=[],
diff --git a/tests/test_translation.py b/tests/test_translation.py
index 1079036dd..95cf2a539 100644
--- a/tests/test_translation.py
+++ b/tests/test_translation.py
@@ -1,21 +1,19 @@
-import pytest
-
from daras_ai_v2.asr import run_google_translate
TRANSLATION_TESTS = [
# hindi romanized
(
"Hi Sir Mera khet me mircha ke ped me fal gal Kar gir hai to iske liye ham kon sa dawa de please help me",
- "Hi sir the fruits of chilli tree in my field have rotted and fallen so what medicine should we give for this please help",
+ "hi sir the fruit of the chilli tree in my field has rotted and fallen so what medicine should we give for this please help",
),
(
"Mirchi ka ped",
- "chili tree",
+ "chilli tree",
),
# hindi
(
"ान का नर्सरी खेत में रोकने के लिए कितने दिन में तैयार हो जाता है",
- "in how many days the corn nursery is ready to stop in the field",
+ "in how many days does the seed nursery become ready to be planted in the field?",
),
# telugu
(
@@ -44,8 +42,12 @@
]
-@pytest.mark.parametrize("text, expected", TRANSLATION_TESTS)
-def test_run_google_translate(text: str, expected: str):
+def test_run_google_translate(threadpool_subtest):
+ for text, expected in TRANSLATION_TESTS:
+ threadpool_subtest(_test_run_google_translate, text, expected)
+
+
+def _test_run_google_translate(text: str, expected: str):
actual = run_google_translate([text], "en")[0]
assert (
actual.replace(".", "").replace(",", "").strip().lower()
diff --git a/url_shortener/tests.py b/url_shortener/tests.py
index 5e028329a..2589434f1 100644
--- a/url_shortener/tests.py
+++ b/url_shortener/tests.py
@@ -9,14 +9,14 @@
client = TestClient(app)
-def test_url_shortener():
+def test_url_shortener(transactional_db):
surl = ShortenedURL.objects.create(url=TEST_URL)
short_url = surl.shortened_url()
r = client.get(short_url, allow_redirects=False)
assert r.is_redirect and r.headers["location"] == TEST_URL
-def test_url_shortener_max_clicks():
+def test_url_shortener_max_clicks(transactional_db):
surl = ShortenedURL.objects.create(url=TEST_URL, max_clicks=5)
short_url = surl.shortened_url()
for _ in range(5):
@@ -26,14 +26,14 @@ def test_url_shortener_max_clicks():
assert r.status_code == 410
-def test_url_shortener_disabled():
+def test_url_shortener_disabled(transactional_db):
surl = ShortenedURL.objects.create(url=TEST_URL, disabled=True)
short_url = surl.shortened_url()
r = client.get(short_url, allow_redirects=False)
assert r.status_code == 410
-def test_url_shortener_create_atomic():
+def test_url_shortener_create_atomic(transactional_db):
def create(_):
return [
ShortenedURL.objects.create(url=TEST_URL).shortened_url()
@@ -43,7 +43,7 @@ def create(_):
assert len(set(flatmap_parallel(create, range(5)))) == 500
-def test_url_shortener_clicks_decrement_atomic():
+def test_url_shortener_clicks_decrement_atomic(transactional_db):
surl = ShortenedURL.objects.create(url=TEST_URL)
short_url = surl.shortened_url()