diff --git a/auth_backend.py b/auth/auth_backend.py
similarity index 76%
rename from auth_backend.py
rename to auth/auth_backend.py
index 24dded058..d197ef213 100644
--- a/auth_backend.py
+++ b/auth/auth_backend.py
@@ -6,36 +6,37 @@
 from starlette.concurrency import run_in_threadpool
 
 from app_users.models import AppUser
-from daras_ai_v2.crypto import get_random_string, get_random_doc_id
+from daras_ai_v2.crypto import get_random_doc_id
 from daras_ai_v2.db import FIREBASE_SESSION_COOKIE, ANONYMOUS_USER_COOKIE
 from gooeysite.bg_db_conn import db_middleware
 
 # quick and dirty way to bypass authentication for testing
-_forced_auth_user = []
+authlocal = []
 
 
 @contextmanager
 def force_authentication():
-    user = AppUser.objects.create(
-        is_anonymous=True, uid=get_random_doc_id(), balance=10**9
+    authlocal.append(
+        AppUser.objects.get_or_create(
+            email="tests@pytest.org",
+            defaults=dict(is_anonymous=True, uid=get_random_doc_id(), balance=10**9),
+        )[0]
     )
     try:
-        _forced_auth_user.append(user)
-        yield
+        yield authlocal[0]
     finally:
-        _forced_auth_user.clear()
+        authlocal.clear()
 
 
 class SessionAuthBackend(AuthenticationBackend):
     async def authenticate(self, conn):
-        return await run_in_threadpool(authenticate, conn)
+        if authlocal:
+            return AuthCredentials(["authenticated"]), authlocal[0]
+        return await run_in_threadpool(_authenticate, conn)
 
 
 @db_middleware
-def authenticate(conn):
-    if _forced_auth_user:
-        return AuthCredentials(["authenticated"]), _forced_auth_user[0]
-
+def _authenticate(conn):
     session_cookie = conn.session.get(FIREBASE_SESSION_COOKIE)
     if not session_cookie:
         # Session cookie is unavailable. Check if anonymous user is available.
@@ -51,7 +52,7 @@
         # Session cookie is unavailable. Force user to login.
         return AuthCredentials(), None
 
-    user = verify_session_cookie(session_cookie)
+    user = _verify_session_cookie(session_cookie)
     if not user:
         # Session cookie was invalid
         conn.session.pop(FIREBASE_SESSION_COOKIE, None)
@@ -60,7 +61,7 @@
     return AuthCredentials(["authenticated"]), user
 
 
-def verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
+def _verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
     # Verify the session cookie. In this case an additional check is added to detect
     # if the user's Firebase session was revoked, user deleted/disabled, etc.
     try:
diff --git a/gooey_token_authentication1/token_authentication.py b/auth/token_authentication.py
similarity index 94%
rename from gooey_token_authentication1/token_authentication.py
rename to auth/token_authentication.py
index 4e81c42b0..b33bbbbd0 100644
--- a/gooey_token_authentication1/token_authentication.py
+++ b/auth/token_authentication.py
@@ -1,8 +1,10 @@
+import threading
+
 from fastapi import Header
 from fastapi.exceptions import HTTPException
 
 from app_users.models import AppUser
-from auth_backend import _forced_auth_user
+from auth.auth_backend import authlocal
 from daras_ai_v2 import db
 from daras_ai_v2.crypto import PBKDF2PasswordHasher
@@ -15,9 +17,8 @@ def api_auth_header(
         description=f"{auth_keyword} $GOOEY_API_KEY",
     ),
 ) -> AppUser:
-    if _forced_auth_user:
-        return _forced_auth_user[0]
-
+    if authlocal:
+        return authlocal[0]
     return authenticate(authorization)
diff --git a/bots/models.py b/bots/models.py
index b4eb715c2..0f0ddce50 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -68,6 +68,7 @@ class Workflow(models.IntegerChoices):
     RELATED_QNA_MAKER = (27, "Related QnA Maker")
     RELATED_QNA_MAKER_DOC = (28, "Related QnA Maker Doc")
     EMBEDDINGS = (29, "Embeddings")
+    BULK_RUNNER = (30, "Bulk Runner")
 
     @property
     def short_slug(self):
@@ -239,7 +240,7 @@ def submit_api_call(
         kwds=dict(
             page_cls=Workflow(self.workflow).page_cls,
             query_params=dict(
-                example_id=self.example_id, run_id=self.id, uid=self.uid
+                example_id=self.example_id, run_id=self.run_id, uid=self.uid
             ),
             user=current_user,
             request_body=request_body,
diff --git a/bots/tests.py b/bots/tests.py
index 0bc887389..1fb66e16b 100644
--- a/bots/tests.py
+++ b/bots/tests.py
@@ -14,7 +14,7 @@ CHATML_ROLE_ASSISSTANT = "assistant"
 
 
-def test_add_balance_direct():
+def test_add_balance_direct(transactional_db):
     pk = AppUser.objects.create(balance=0, is_anonymous=False).pk
 
     amounts = [[random.randint(-100, 10_000) for _ in range(100)] for _ in range(5)]
@@ -28,7 +28,7 @@ def worker(amts):
     assert AppUser.objects.get(pk=pk).balance == sum(map(sum, amounts))
 
 
-def test_create_bot_integration_conversation_message():
+def test_create_bot_integration_conversation_message(transactional_db):
     # Create a new BotIntegration with WhatsApp as the platform
     bot_integration = BotIntegration.objects.create(
         name="My Bot Integration",
diff --git a/conftest.py b/conftest.py
index 7bea1a865..ff61671aa 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,7 +1,17 @@
+import typing
+from functools import wraps
+from threading import Thread
+from unittest.mock import patch
+
 import pytest
+from pytest_subtests import subtests
+
+from auth import auth_backend
+from celeryapp import app
+from daras_ai_v2.base import BasePage
 
-@pytest.fixture
+
+@pytest.fixture(scope="session")
 def django_db_setup(django_db_setup, django_db_blocker):
     with django_db_blocker.unblock():
         from django.core.management import call_command
@@ -9,12 +19,131 @@
         call_command("loaddata", "fixture.json")
 
 
-@pytest.fixture(
-    # add this fixture to all tests
-    autouse=True
-)
-def enable_db_access_for_all_tests(
-    # enable transactional db
-    transactional_db,
+@pytest.fixture
+def force_authentication():
+    with auth_backend.force_authentication() as user:
+        yield user
+
+
+app.conf.task_always_eager = True
+
+
+@pytest.fixture
+def mock_gui_runner():
+    with patch("celeryapp.tasks.gui_runner", _mock_gui_runner):
+        yield
+
+
+@app.task
+def _mock_gui_runner(
+    *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
 ):
-    pass
+    sr = page_cls.run_doc_sr(run_id, uid)
+    sr.set(sr.parent.to_dict())
+    sr.save()
+
+
+@pytest.fixture
+def threadpool_subtest(subtests, max_workers: int = 8):
+    ts = []
+
+    def submit(fn, *args, **kwargs):
+        msg = "--".join(map(str, args))
+
+        @wraps(fn)
+        def runner(*args, **kwargs):
+            with subtests.test(msg=msg):
+                return fn(*args, **kwargs)
+
+        ts.append(Thread(target=runner, args=args, kwargs=kwargs))
+
+    yield submit
+
+    for i in range(0, len(ts), max_workers):
+        s = slice(i, i + max_workers)
+        for t in ts[s]:
+            t.start()
+        for t in ts[s]:
+            t.join()
+
+
+# class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker):
+#     class _dj_db_wrapper:
+#         def ensure_connection(self):
+#             pass
+#
+#
+# @pytest.mark.tryfirst
+# def pytest_sessionstart(session):
+#     # make the session threadsafe
+#     _pytest.runner.SetupState = ThreadLocalSetupState
+#
+#     # ensure that the fixtures (specifically finalizers) are threadsafe
+#     _pytest.fixtures.FixtureDef = ThreadLocalFixtureDef
+#
+#     django.test.utils._TestState = threading.local()
+#
+#     # make the environment threadsafe
+#     os.environ = ThreadLocalEnviron(os.environ)
+#
+#
+# @pytest.hookimpl
+# def pytest_runtestloop(session: _pytest.main.Session):
+#     if session.testsfailed and not session.config.option.continue_on_collection_errors:
+#         raise session.Interrupted(
+#             "%d error%s during collection"
+#             % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+#         )
+#
+#     if session.config.option.collectonly:
+#         return True
+#
+#     num_threads = 10
+#
+#     threadpool_items = [
+#         item
+#         for item in session.items
+#         if any("run_in_threadpool" in marker.name for marker in item.iter_markers())
+#     ]
+#
+#     manager = pytest_django.plugin._blocking_manager
+#     pytest_django.plugin._blocking_manager = DummyDatabaseBlocker()
+#     try:
+#         with manager.unblock():
+#             for j in range(0, len(threadpool_items), num_threads):
+#                 s = slice(j, j + num_threads)
+#                 futs = []
+#                 for i, item in enumerate(threadpool_items[s]):
+#                     nextitem = (
+#                         threadpool_items[i + 1]
+#                         if i + 1 < len(threadpool_items)
+#                         else None
+#                     )
+#                     p = threading.Thread(
+#                         target=item.config.hook.pytest_runtest_protocol,
+#                         kwargs=dict(item=item, nextitem=nextitem),
+#                     )
+#                     p.start()
+#                     futs.append(p)
+#                 for p in futs:
+#                     p.join()
+#                 if session.shouldfail:
+#                     raise session.Failed(session.shouldfail)
+#                 if session.shouldstop:
+#                     raise session.Interrupted(session.shouldstop)
+#     finally:
+#         pytest_django.plugin._blocking_manager = manager
+#
+#     session.items = [
+#         item
+#         for item in session.items
+#         if not any("run_in_threadpool" in marker.name for marker in item.iter_markers())
+#     ]
+#     for i, item in enumerate(session.items):
+#         nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
+#         item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
+#     if session.shouldfail:
+#         raise session.Failed(session.shouldfail)
+#     if session.shouldstop:
+#         raise session.Interrupted(session.shouldstop)
+#     return True
diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py
index 5b67d0445..96bd7734c 100644
--- a/daras_ai/image_input.py
+++ b/daras_ai/image_input.py
@@ -10,6 +10,9 @@
 from daras_ai_v2 import settings
 
+if False:
+    from firebase_admin import storage
+
 
 def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
     img_cv2 = bytes_to_cv2_img(img_bytes)
@@ -54,15 +57,10 @@ def upload_file_from_bytes(
     data: bytes,
     content_type: str = None,
 ) -> str:
-    from firebase_admin import storage
-
     if not content_type:
         content_type = mimetypes.guess_type(filename)[0]
     content_type = content_type or "application/octet-stream"
-
-    filename = safe_filename(filename)
-    bucket = storage.bucket(settings.GS_BUCKET_NAME)
-    blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
+    blob = storage_blob_for(filename)
     blob.upload_from_string(data, content_type=content_type)
     return blob.public_url
diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py
index 5befa1451..85047e770 100644
--- a/daras_ai_v2/all_pages.py
+++ b/daras_ai_v2/all_pages.py
@@ -3,6 +3,7 @@
 from bots.models import Workflow
 from daras_ai_v2.base import BasePage
+from recipes.BulkRunner import BulkRunnerPage
 from recipes.ChyronPlant import ChyronPlantPage
 from recipes.CompareLLM import CompareLLMPage
 from recipes.CompareText2Img import CompareText2ImgPage
@@ -71,6 +72,7 @@
     ImageSegmentationPage,
     CompareUpscalerPage,
     DocExtractPage,
+    BulkRunnerPage,
 ]
 
 # exposed as API
diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py
index 7f15b06f4..da9a25a5e 100644
--- a/daras_ai_v2/api_examples_widget.py
+++ b/daras_ai_v2/api_examples_widget.py
@@ -6,7 +6,7 @@
 from furl import furl
 
 from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url
-from gooey_token_authentication1.token_authentication import auth_keyword
+from auth.token_authentication import auth_keyword
 
 
 def get_filenames(request_body):
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index c0d594dd2..9744110c3 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -365,30 +365,35 @@ def get_sr_from_query_params_dict(self, query_params) -> SavedRun:
         example_id, run_id, uid = extract_query_params(query_params)
         return self.get_sr_from_query_params(example_id, run_id, uid)
 
-    def get_sr_from_query_params(self, example_id, run_id, uid) -> SavedRun:
+    @classmethod
+    def get_sr_from_query_params(
+        cls, example_id: str, run_id: str, uid: str
+    ) -> SavedRun:
         try:
             if run_id and uid:
-                sr = self.run_doc_sr(run_id, uid)
+                sr = cls.run_doc_sr(run_id, uid)
             elif example_id:
-                sr = self.example_doc_sr(example_id)
+                sr = cls.example_doc_sr(example_id)
             else:
-                sr = self.recipe_doc_sr()
+                sr = cls.recipe_doc_sr()
             return sr
         except SavedRun.DoesNotExist:
             raise HTTPException(status_code=404)
 
-    def recipe_doc_sr(self) -> SavedRun:
+    @classmethod
+    def recipe_doc_sr(cls) -> SavedRun:
         return SavedRun.objects.get_or_create(
-            workflow=self.workflow,
+            workflow=cls.workflow,
             run_id__isnull=True,
             uid__isnull=True,
             example_id__isnull=True,
         )[0]
 
+    @classmethod
     def run_doc_sr(
-        self, run_id: str, uid: str, create: bool = False, parent: SavedRun = None
+        cls, run_id: str, uid: str, create: bool = False, parent: SavedRun = None
     ) -> SavedRun:
-        config = dict(workflow=self.workflow, uid=uid, run_id=run_id)
+        config = dict(workflow=cls.workflow, uid=uid, run_id=run_id)
         if create:
             return SavedRun.objects.get_or_create(
                 **config, defaults=dict(parent=parent)
@@ -396,8 +401,9 @@ def run_doc_sr(
         else:
             return SavedRun.objects.get(**config)
 
-    def example_doc_sr(self, example_id: str, create: bool = False) -> SavedRun:
-        config = dict(workflow=self.workflow, example_id=example_id)
+    @classmethod
+    def example_doc_sr(cls, example_id: str, create: bool = False) -> SavedRun:
+        config = dict(workflow=cls.workflow, example_id=example_id)
         if create:
             return SavedRun.objects.get_or_create(**config)[0]
         else:
@@ -851,7 +857,7 @@ def _render(sr: SavedRun):
             workflow=self.workflow,
             hidden=False,
             example_id__isnull=False,
-        ).exclude()[:50]
+        )[:50]
 
         grid_layout(3, example_runs, _render)
diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py
index 31865022b..5c3c2de59 100644
--- a/daras_ai_v2/doc_search_settings_widgets.py
+++ b/daras_ai_v2/doc_search_settings_widgets.py
@@ -1,3 +1,5 @@
+import typing
+
 import gooey_ui as st
 
 from daras_ai_v2 import settings
@@ -12,8 +14,18 @@ def is_user_uploaded_url(url: str) -> bool:
 
 
 def document_uploader(
     label: str,
-    key="documents",
-    accept=(".pdf", ".txt", ".docx", ".md", ".html", ".wav", ".ogg", ".mp3", ".aac"),
+    key: str = "documents",
+    accept: typing.Iterable[str] = (
+        ".pdf",
+        ".txt",
+        ".docx",
+        ".md",
+        ".html",
+        ".wav",
+        ".ogg",
+        ".mp3",
+        ".aac",
+    ),
 ):
     st.write(label, className="gui-input")
     documents = st.session_state.get(key) or []
@@ -45,6 +57,7 @@ def document_uploader(
         accept=accept,
         accept_multiple_files=True,
     )
+    return st.session_state.get(key, [])
 
 
 def doc_search_settings(
diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 657d2898c..1baf64b47 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -3,12 +3,13 @@
 from enum import Enum
 
 import jinja2
+from typing_extensions import TypedDict
 
 import gooey_ui
 from daras_ai_v2.scrollable_html_widget import scrollable_html
 
 
-class SearchReference(typing.TypedDict):
+class SearchReference(TypedDict):
     url: str
     title: str
     snippet: str
@@ -19,17 +20,23 @@ class CitationStyles(Enum):
     number = "Numbers ( [1] [2] [3] ..)"
     title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)"
     url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)"
+    symbol = "Symbols ( [*] [†] [‡] ..)"
 
     markdown = "Markdown ( [Source 1](https://source1.com) [Source 2](https://source2.com) [Source 3](https://source3.com) ..)"
     html = "HTML ( <a href='https://source1.com'>Source 1</a> <a href='https://source2.com'>Source 2</a> <a href='https://source3.com'>Source 3</a> ..)"
    slack_mrkdwn = "Slack mrkdwn ( <https://source1.com|Source 1> <https://source2.com|Source 2> <https://source3.com|Source 3> ..)"
     plaintext = "Plain Text / WhatsApp ( [Source 1 https://source1.com] [Source 2 https://source2.com] [Source 3 https://source3.com] ..)"
-
     number_markdown = " Markdown Numbers + Footnotes"
+
     number_html = "HTML Numbers + Footnotes"
     number_slack_mrkdwn = "Slack mrkdown Numbers + Footnotes"
     number_plaintext = "Plain Text / WhatsApp Numbers + Footnotes"
 
+    symbol_markdown = "Markdown Symbols + Footnotes"
+    symbol_html = "HTML Symbols + Footnotes"
+    symbol_slack_mrkdwn = "Slack mrkdown Symbols + Footnotes"
+    symbol_plaintext = "Plain Text / WhatsApp Symbols + Footnotes"
+
 
 def remove_quotes(snippet: str) -> str:
     return re.sub(r"[\"\']+", r'"', snippet).strip()
@@ -63,36 +70,65 @@ def apply_response_template(
     match citation_style:
         case CitationStyles.number | CitationStyles.number_plaintext:
             cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
-        case CitationStyles.number_html:
+        case CitationStyles.title:
+            cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+        case CitationStyles.url:
+            cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+        case CitationStyles.symbol | CitationStyles.symbol_plaintext:
             cites = " ".join(
-                html_link(f"[{ref_num}]", ref["url"])
+                f"[{generate_footnote_symbol(ref_num - 1)}]"
+                for ref_num in ref_map.keys()
+            )
+
+        case CitationStyles.markdown:
+            cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+        case CitationStyles.html:
+            cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+        case CitationStyles.slack_mrkdwn:
+            cites = " ".join(
+                ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+            )
+        case CitationStyles.plaintext:
+            cites = " ".join(
+                f'[{ref["title"]} {ref["url"]}]'
                 for ref_num, ref in ref_map.items()
             )
+
         case CitationStyles.number_markdown:
             cites = " ".join(
                 markdown_link(f"[{ref_num}]", ref["url"])
                 for ref_num, ref in ref_map.items()
             )
+        case CitationStyles.number_html:
+            cites = " ".join(
+                html_link(f"[{ref_num}]", ref["url"])
+                for ref_num, ref in ref_map.items()
+            )
         case CitationStyles.number_slack_mrkdwn:
             cites = " ".join(
                 slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
                 for ref_num, ref in ref_map.items()
             )
-        case CitationStyles.title:
-            cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
-        case CitationStyles.url:
-            cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
-        case CitationStyles.markdown:
-            cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
-        case CitationStyles.html:
-            cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
-        case CitationStyles.slack_mrkdwn:
+
+        case CitationStyles.symbol_markdown:
             cites = " ".join(
-                ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+                markdown_link(
+                    f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+                )
+                for ref_num, ref in ref_map.items()
             )
-        case CitationStyles.plaintext:
+        case CitationStyles.symbol_html:
             cites = " ".join(
-                f'[{ref["title"]} {ref["url"]}]'
+                html_link(
+                    f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+                )
+                for ref_num, ref in ref_map.items()
+            )
+        case CitationStyles.symbol_slack_mrkdwn:
+            cites = " ".join(
+                slack_mrkdwn_link(
+                    f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+                )
                 for ref_num, ref in ref_map.items()
             )
         case None:
@@ -128,6 +164,31 @@
                 for ref_num, ref in sorted(all_refs.items())
             )
 
+        case CitationStyles.symbol_markdown:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_html:
+            formatted += "<br><br>"
+            formatted += "<br>".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_slack_mrkdwn:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_plaintext:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+                for ref_num, ref in sorted(all_refs.items())
+            )
+
     for ref_num, ref in all_refs.items():
         try:
             template = ref["response_template"]
@@ -205,3 +266,9 @@ def render_output_with_refs(state, height):
     for text in output_text:
         html = render_text_with_refs(text, state.get("references", []))
         scrollable_html(html, height=height)
+
+
+FOOTNOTE_SYMBOLS = ["*", "†", "‡", "§", "¶", "#", "♠", "♥", "♦", "♣", "✠", "☮", "☯", "✡"]  # fmt: skip
+def generate_footnote_symbol(idx: int) -> str:
+    quotient, remainder = divmod(idx, len(FOOTNOTE_SYMBOLS))
+    return FOOTNOTE_SYMBOLS[remainder] * (quotient + 1)
diff --git a/gooey_ui/components.py b/gooey_ui/components.py
index f9b366d1c..ea3bbedda 100644
--- a/gooey_ui/components.py
+++ b/gooey_ui/components.py
@@ -241,7 +241,7 @@ def text_area(
     key = md5_values(
         "textarea", label, height, help, value, placeholder, label_visibility
     )
-    value = state.session_state.setdefault(key, value)
+    value = str(state.session_state.setdefault(key, value))
     if label_visibility != "visible":
         label = None
     if disabled:
@@ -462,6 +462,10 @@ def json(value: typing.Any, expanded: bool = False, depth: int = 1):
     ).mount()
 
 
+def data_table(file_url: str):
+    return _node("data-table", fileUrl=file_url)
+
+
 def table(df: "pd.DataFrame"):
     state.RenderTreeNode(
         name="table",
diff --git a/poetry.lock b/poetry.lock
index 32a70bdde..de91446e3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -173,24 +173,24 @@ vine = ">=5.0.0"
 
 [[package]]
 name = "anyio"
-version = "4.0.0"
+version = "3.7.1"
 description = "High level compatibility layer for multiple asynchronous event loop implementations"
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.7"
 files = [
-    {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"},
-    {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"},
+    {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"},
+    {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"},
 ]
 
 [package.dependencies]
-exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
+exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
 idna = ">=2.8"
 sniffio = ">=1.1"
 
 [package.extras]
-doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"]
-test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
-trio = ["trio (>=0.22)"]
+doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"]
+test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
+trio = ["trio (<0.22)"]
 
 [[package]]
 name = "appdirs"
@@ -1483,11 +1483,8 @@ files = [
 
 [package.dependencies]
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
-grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
-    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-]
-grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}
+grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
+grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
 requests = ">=2.18.0,<3.0.0.dev0"
@@ -1604,8 +1601,8 @@ files = [
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 google-cloud-core = ">=1.4.1,<3.0.0dev"
 proto-plus = [
-    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
     {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -1623,8 +1620,8 @@ files = [
 [package.dependencies]
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 proto-plus = [
-    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
     {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -1663,8 +1660,8 @@ files = [
 [package.dependencies]
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 proto-plus = [
-    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
     {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -3395,12 +3392,9 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
-    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
     {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
-    {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
-    {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
-    {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
+    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
+    {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
 ]
 
 [[package]]
@@ -3458,8 +3452,8 @@ files = [
 [package.dependencies]
 numpy = [
-    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
     {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -4238,6 +4232,21 @@ pytest = ">=5.4.0"
 docs = ["sphinx", "sphinx-rtd-theme"]
 testing = ["Django", "django-configurations (>=2.0)"]
 
+[[package]]
+name = "pytest-subtests"
+version = "0.11.0"
+description = "unittest subTest() support and subtests fixture"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-subtests-0.11.0.tar.gz", hash = "sha256:51865c88457545f51fb72011942f0a3c6901ee9e24cbfb6d1b9dc1348bafbe37"},
+    {file = "pytest_subtests-0.11.0-py3-none-any.whl", hash = "sha256:453389984952eec85ab0ce0c4f026337153df79587048271c7fd0f49119c07e4"},
+]
+
+[package.dependencies]
+attrs = ">=19.2.0"
+pytest = ">=7.0"
+
 [[package]]
 name = "pytest-xdist"
 version = "3.3.1"
@@ -5204,7 +5213,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"}
+greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
 
 [package.extras]
 aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]
@@ -6129,4 +6138,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "e1aa198ba112e95195815327669b1e7687428ab0510c5485f057442a207b982e"
+content-hash = "8cb6f5a826bc1bbd06c65a871e027cf967178cd2e8e6c33c262b80a43a772cfe"
diff --git a/pyproject.toml b/pyproject.toml
index 3ea6240be..6c697412c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,12 +18,12 @@ pandas = "^2.0.1"
 google-cloud-firestore = "^2.7.0"
 replicate = "^0.4.0"
 fastapi = "^0.85.0"
-uvicorn = {extras = ["standard"], version = "^0.18.3"}
+uvicorn = { extras = ["standard"], version = "^0.18.3" }
 firebase-admin = "^6.0.0"
 # mediapipe for M1 macs
-mediapipe-silicon = {version = "^0.8.11", markers = "platform_machine == 'arm64'", platform = "darwin"}
+mediapipe-silicon = { version = "^0.8.11", markers = "platform_machine == 'arm64'", platform = "darwin" }
 # mediapipe for others
-mediapipe = {version = "^0.8.11", markers = "platform_machine != 'arm64'"}
+mediapipe = { version = "^0.8.11", markers = "platform_machine != 'arm64'" }
 furl = "^2.1.3"
 itsdangerous = "^2.1.2"
 pytest = "^7.2.0"
@@ -53,7 +53,7 @@ llama-index = "^0.5.27"
 nltk = "^3.8.1"
 Jinja2 = "^3.1.2"
 Django = "^4.2"
-django-phonenumber-field = {extras = ["phonenumberslite"], version = "^7.0.2"}
+django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" }
 gunicorn = "^20.1.0"
 psycopg2-binary = "^2.9.6"
 whitenoise = "^6.4.0"
@@ -75,6 +75,8 @@ tabulate = "^0.9.0"
 deepgram-sdk = "^2.11.0"
 scipy = "^1.11.2"
 rank-bm25 = "^0.2.2"
+pytest-subtests = "^0.11.0"
+anyio = "^3.4.0"
 
 [tool.poetry.group.dev.dependencies]
 watchdog = "^2.1.9"
diff --git a/pytest.ini b/pytest.ini
index 55040e469..4a9863462 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,4 @@
 [pytest]
-addopts = --tb=native -vv -n 16 --disable-warnings
+addopts = --tb=native --disable-warnings
 DJANGO_SETTINGS_MODULE = daras_ai_v2.settings
-python_files = tests.py test_*.py *_tests.py
\ No newline at end of file
+python_files = tests.py test_*.py *_tests.py
diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py
new file mode 100644
index 000000000..5d5f9b554
--- /dev/null
+++ b/recipes/BulkRunner.py
@@ -0,0 +1,357 @@
+import io
+import typing
+
+import pandas as pd
+import requests
+from fastapi import HTTPException
+from furl import furl
+from pydantic import BaseModel
+
+import gooey_ui as st
+from bots.models import Workflow
+from daras_ai.image_input import upload_file_from_bytes
+from daras_ai_v2.base import BasePage
+from daras_ai_v2.doc_search_settings_widgets import document_uploader
+from daras_ai_v2.functional import map_parallel
+from daras_ai_v2.query_params_util import extract_query_params
+from recipes.DocSearch import render_documents
+
+CACHED_COLUMNS = "__cached_columns"
+
+
+class BulkRunnerPage(BasePage):
+    title = "Bulk Runner"
+    workflow = Workflow.BULK_RUNNER
+    slug_versions = ["bulk-runner", "bulk"]
+
+    class RequestModel(BaseModel):
+        documents: list[str]
+        run_urls: list[str]
+
+        input_columns: dict[str, str]
+        output_columns: dict[str, str]
+
+    class ResponseModel(BaseModel):
+        output_documents: list[str]
+
+    def fields_to_save(self) -> [str]:
+        return super().fields_to_save() + [CACHED_COLUMNS]
+
+    def render_form_v2(self):
+        from daras_ai_v2.all_pages import page_slug_map, normalize_slug
+
+        run_urls = st.session_state.get("run_urls", "")
+        st.session_state.setdefault("__run_urls", "\n".join(run_urls))
+        run_urls = (
+            st.text_area("##### Run URL(s)", key="__run_urls").strip().splitlines()
+        )
+        st.session_state["run_urls"] = run_urls
+
+        files = document_uploader(
+            "##### Upload a File",
".tsv", ".xml"), + ) + + if files: + st.session_state[CACHED_COLUMNS] = list( + { + col: None + for df in map_parallel(_read_df, files) + for col in df.columns + if not col.startswith("Unnamed:") + } + ) + else: + st.session_state.pop(CACHED_COLUMNS, None) + + required_input_fields = {} + optional_input_fields = {} + output_fields = {} + + for url in run_urls: + f = furl(url) + slug = f.path.segments[0] + try: + page_cls = page_slug_map[normalize_slug(slug)] + except KeyError as e: + st.error(repr(e)) + continue + + example_id, run_id, uid = extract_query_params(f.query.params) + try: + sr = page_cls.get_sr_from_query_params(example_id, run_id, uid) + except HTTPException as e: + st.error(repr(e)) + continue + + schema = page_cls.RequestModel.schema(ref_template="{model}") + for field, model_field in page_cls.RequestModel.__fields__.items(): + if model_field.required: + input_fields = required_input_fields + else: + input_fields = optional_input_fields + field_props = schema["properties"][field] + title = field_props["title"] + keys = None + if is_arr(field_props): + try: + ref = field_props["items"]["$ref"] + props = schema["definitions"][ref]["properties"] + keys = {k: prop["title"] for k, prop in props.items()} + except KeyError: + try: + keys = {k: k for k in sr.state[field][0].keys()} + except (KeyError, IndexError, AttributeError): + pass + elif field_props.get("type") == "object": + try: + keys = {k: k for k in sr.state[field].keys()} + except (KeyError, AttributeError): + pass + if keys: + for k, ktitle in keys.items(): + input_fields[f"{field}.{k}"] = f"{title}.{ktitle}" + else: + input_fields[field] = title + + schema = page_cls.ResponseModel.schema() + output_fields |= { + field: schema["properties"][field]["title"] + for field, model_field in page_cls.ResponseModel.__fields__.items() + } + + columns = st.session_state.get(CACHED_COLUMNS, []) + if not columns: + return + + for file in files: + st.data_table(file) + + if not (required_input_fields or optional_input_fields): + return + + col1, col2 = st.columns(2) + + with col1: + st.write("##### Input Columns") + + input_columns_old = st.session_state.pop("input_columns", {}) + input_columns_new = st.session_state.setdefault("input_columns", {}) + + column_options = [None, *columns] + for fields in (required_input_fields, optional_input_fields): + for field, title in fields.items(): + col = st.selectbox( + label="`" + title + "`", + options=column_options, + key="--input-mapping:" + field, + default_value=input_columns_old.get(field), + ) + if col: + input_columns_new[field] = col + st.write("---") + + with col2: + st.write("##### Output Columns") + + output_columns_old = st.session_state.pop("output_columns", {}) + output_columns_new = st.session_state.setdefault("output_columns", {}) + + prev_fields = st.session_state.get("--prev-output-fields") + fields = {**output_fields, "error_msg": "Error Msg", "run_url": "Run URL"} + did_change = prev_fields is not None and prev_fields != fields + st.session_state["--prev-output-fields"] = fields + for field, title in fields.items(): + col = st.text_input( + label="`" + title + "`", + key="--output-mapping:" + field, + value=output_columns_old.get(field, title if did_change else None), + ) + if col: + output_columns_new[field] = col + + def render_example(self, state: dict): + render_documents(state) + + def render_output(self): + files = st.session_state.get("output_documents", []) + for file in files: + st.write(file) + st.data_table(file) + + def run_v2( + self, + request: 
"BulkRunnerPage.RequestModel", + response: "BulkRunnerPage.ResponseModel", + ) -> typing.Iterator[str | None]: + response.output_documents = [] + + for doc_ix, doc in enumerate(request.documents): + df = _read_df(doc) + in_recs = df.to_dict(orient="records") + out_recs = [] + + f = upload_file_from_bytes( + filename=f"bulk-runner-{doc_ix}-0-0.csv", + data=df.to_csv(index=False).encode(), + content_type="text/csv", + ) + response.output_documents.append(f) + + df_slices = list(slice_request_df(df, request)) + for slice_ix, (df_ix, arr_len) in enumerate(df_slices): + rec_ix = len(out_recs) + out_recs.extend(in_recs[df_ix : df_ix + arr_len]) + + for url_ix, f, request_body, page_cls in build_requests_for_df( + df, request, df_ix, arr_len + ): + progress = round( + (slice_ix + url_ix) + / (len(df_slices) + len(request.run_urls)) + * 100 + ) + yield f"{progress}%" + + example_id, run_id, uid = extract_query_params(f.query.params) + sr = page_cls.get_sr_from_query_params(example_id, run_id, uid) + + result, sr = sr.submit_api_call( + current_user=self.request.user, request_body=request_body + ) + result.get(disable_sync_subtasks=False) + sr.refresh_from_db() + state = sr.to_dict() + state["run_url"] = sr.get_app_url() + state["error_msg"] = sr.error_msg + + for field, col in request.output_columns.items(): + if len(request.run_urls) > 1: + col = f"({url_ix + 1}) {col}" + out_val = state.get(field) + if isinstance(out_val, list): + for arr_ix, item in enumerate(out_val): + if len(out_recs) <= rec_ix + arr_ix: + out_recs.append({}) + if isinstance(item, dict): + for key, val in item.items(): + out_recs[rec_ix + arr_ix][f"{col}.{key}"] = str( + val + ) + else: + out_recs[rec_ix + arr_ix][col] = str(item) + elif isinstance(out_val, dict): + for key, val in out_val.items(): + if isinstance(val, list): + for arr_ix, item in enumerate(val): + if len(out_recs) <= rec_ix + arr_ix: + out_recs.append({}) + out_recs[rec_ix + arr_ix][f"{col}.{key}"] = str( + item + ) + else: + out_recs[rec_ix][f"{col}.{key}"] = str(val) + else: + out_recs[rec_ix][col] = str(out_val) + + out_df = pd.DataFrame.from_records(out_recs) + f = upload_file_from_bytes( + filename=f"bulk-runner-{doc_ix}-{url_ix}-{df_ix}.csv", + data=out_df.to_csv(index=False).encode(), + content_type="text/csv", + ) + response.output_documents[doc_ix] = f + + +def build_requests_for_df(df, request, df_ix, arr_len): + from daras_ai_v2.all_pages import page_slug_map, normalize_slug + + for url_ix, url in enumerate(request.run_urls): + f = furl(url) + slug = f.path.segments[0] + page_cls = page_slug_map[normalize_slug(slug)] + schema = page_cls.RequestModel.schema() + properties = schema["properties"] + + request_body = {} + for field, col in request.input_columns.items(): + parts = field.split(".") + field_props = properties.get(parts[0]) or properties.get(parts) + if is_arr(field_props): + arr = request_body.setdefault(parts[0], []) + for arr_ix in range(arr_len): + value = df.at[df_ix + arr_ix, col] + if len(parts) > 1: + if len(arr) <= arr_ix: + arr.append({}) + arr[arr_ix][parts[1]] = value + else: + if len(arr) <= arr_ix: + arr.append(None) + arr[arr_ix] = value + elif len(parts) > 1 and field_props.get("type") == "object": + obj = request_body.setdefault(parts[0], {}) + obj[parts[1]] = df.at[df_ix, col] + else: + request_body[field] = df.at[df_ix, col] + # for validation + request_body = page_cls.RequestModel.parse_obj(request_body).dict() + + yield url_ix, f, request_body, page_cls + + +def slice_request_df(df, request): + from 
+    from daras_ai_v2.all_pages import page_slug_map, normalize_slug
+
+    non_array_cols = set()
+    for url_ix, url in enumerate(request.run_urls):
+        f = furl(url)
+        slug = f.path.segments[0]
+        page_cls = page_slug_map[normalize_slug(slug)]
+        schema = page_cls.RequestModel.schema()
+        properties = schema["properties"]
+
+        for field, col in request.input_columns.items():
+            if is_arr(properties.get(field.split(".")[0])):
+                non_array_cols.add(col)
+    non_array_df = df[list(non_array_cols)]
+
+    df_ix = 0
+    while df_ix < len(df):
+        arr_len = 1
+        while df_ix + arr_len < len(df):
+            if not non_array_df.iloc[df_ix + arr_len].isnull().all():
+                break
+            arr_len += 1
+        yield df_ix, arr_len
+        df_ix += arr_len
+
+
+def is_arr(field_props: dict) -> bool:
+    try:
+        return field_props["type"] == "array"
+    except KeyError:
+        for props in field_props.get("anyOf", []):
+            if props["type"] == "array":
+                return True
+        return False
+
+
+def _read_df(f: str) -> "pd.DataFrame":
+    import pandas as pd
+
+    r = requests.get(f)
+    r.raise_for_status()
+    if f.endswith(".csv"):
+        df = pd.read_csv(io.StringIO(r.text))
+    elif f.endswith(".xlsx") or f.endswith(".xls"):
+        df = pd.read_excel(io.BytesIO(r.content))
+    elif f.endswith(".json"):
+        df = pd.read_json(io.StringIO(r.text))
+    elif f.endswith(".tsv"):
+        df = pd.read_csv(io.StringIO(r.text), sep="\t")
+    elif f.endswith(".xml"):
+        df = pd.read_xml(io.StringIO(r.text))
+    else:
+        raise ValueError(f"Unsupported file type: {f}")
+    return df.dropna(how="all", axis=1).dropna(how="all", axis=0)
diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py
index a30396907..f523ba14e 100644
--- a/recipes/DeforumSD.py
+++ b/recipes/DeforumSD.py
@@ -3,6 +3,7 @@
 from django.db.models import TextChoices
 from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 import gooey_ui as st
 from app_users.models import AppUser
@@ -21,7 +22,7 @@ class AnimationModels(TextChoices):
     epicdream = ("epicdream.safetensors", "epiCDream (epinikion)")
 
 
-class _AnimationPrompt(typing.TypedDict):
+class _AnimationPrompt(TypedDict):
     frame: str
     prompt: str
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 4c8895901..d5f405ed6 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -244,6 +244,7 @@ class ResponseModel(BaseModel):
     final_prompt: str
     raw_input_text: str | None
     raw_output_text: list[str] | None
+    raw_tts_text: list[str] | None
     output_text: list[str]
 
     # tts
@@ -734,11 +735,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
                 source_language="en",
                 target_language=request.user_language,
             )
-
-        tts_text = [
-            "".join(snippet for snippet, _ in parse_refs(text, references))
-            for text in output_text
-        ]
+        state["raw_tts_text"] = [
+            "".join(snippet for snippet, _ in parse_refs(text, references))
+            for text in output_text
+        ]
 
         if references:
             citation_style = (
@@ -754,7 +754,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
         if not request.tts_provider:
             return
         tts_state = dict(state)
-        for text in tts_text:
+        for text in state.get("raw_tts_text", state["raw_output_text"]):
             tts_state["text_prompt"] = text
             yield from TextToSpeechPage().run(tts_state)
             state["output_audio"].append(tts_state["audio_url"])
diff --git a/routers/api.py b/routers/api.py
index 47b7989a1..62514cc14 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -29,7 +29,7 @@
     BasePage,
     StateKeys,
 )
-from gooey_token_authentication1.token_authentication import api_auth_header
+from auth.token_authentication import api_auth_header
 
 app = APIRouter()
diff --git a/routers/root.py b/routers/root.py
index 386cc9813..42627572e 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -21,9 +21,6 @@
 import gooey_ui as st
 from app_users.models import AppUser
-from auth_backend import (
-    FIREBASE_SESSION_COOKIE,
-)
 from daras_ai.image_input import upload_file_from_bytes, safe_filename
 from daras_ai_v2 import settings
 from daras_ai_v2.all_pages import all_api_pages, normalize_slug, page_slug_map
@@ -32,6 +29,7 @@
     RedirectException,
 )
 from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts
+from daras_ai_v2.db import FIREBASE_SESSION_COOKIE
 from daras_ai_v2.meta_content import build_meta_tags
 from daras_ai_v2.query_params_util import extract_query_params
 from daras_ai_v2.settings import templates
diff --git a/routers/slack.py b/routers/slack.py
index 75ed8465e..20e19fec6 100644
--- a/routers/slack.py
+++ b/routers/slack.py
@@ -14,7 +14,6 @@
 from bots.tasks import create_personal_channels_for_all_members
 from daras_ai_v2 import settings
 from daras_ai_v2.bots import _on_msg, request_json, request_urlencoded_body
-from daras_ai_v2.search_ref import parse_refs
 from daras_ai_v2.slack_bot import (
     SlackBot,
     invite_bot_account_to_channel,
@@ -23,8 +22,6 @@
     fetch_user_info,
     parse_slack_response,
 )
-from gooey_token_authentication1.token_authentication import auth_keyword
-from recipes.VideoBots import VideoBotsPage
 
 router = APIRouter()
@@ -368,7 +365,6 @@ def slack_auth_header(
 
 @router.get("/__/slack/get-response-for-msg/{msg_id}/")
 def slack_get_response_for_msg_id(
     msg_id: str,
-    remove_refs: bool = True,
     slack_user: dict = Depends(slack_auth_header),
 ):
     try:
@@ -384,11 +380,13 @@ def slack_get_response_for_msg_id(
     ):
         raise HTTPException(403, "Not authorized")
 
-    output_text = response_msg.saved_run.state.get("output_text")
+    state = response_msg.saved_run.state
+    output_text = (
+        state.get("raw_tts_text")
+        or state.get("raw_output_text")
+        or state.get("output_text")
+    )
     if not output_text:
         return {"status": "no_output"}
 
-    content = output_text[0]
-    if remove_refs:
-        content = "".join(snippet for snippet, _ in parse_refs(content, []))
-
-    return {"status": "ok", "content": content}
+    return {"status": "ok", "content": output_text[0]}
diff --git a/scripts/create_fixture.py b/scripts/create_fixture.py
index d84495221..9482cef64 100644
--- a/scripts/create_fixture.py
+++ b/scripts/create_fixture.py
@@ -1,9 +1,24 @@
-from django.core.management import call_command
+import sys
+
+from django.core import serializers
 
 from bots.models import SavedRun
 
 
 def run():
-    qs = SavedRun.objects.filter(run_id__isnull=True).values_list("pk", flat=True)
-    pks = ",".join(map(str, qs))
-    call_command("dumpdata", "bots.SavedRun", "--pks", pks, "--output", "fixture.json")
+    qs = SavedRun.objects.filter(run_id__isnull=True)
+    with open("fixture.json", "w") as f:
+        serializers.serialize(
+            "json",
+            get_objects(qs),
+            indent=2,
+            stream=f,
+            progress_output=sys.stdout,
+            object_count=qs.count(),
+        )
+
+
+def get_objects(qs):
+    for obj in qs:
+        obj.parent = None
+        yield obj
diff --git a/server.py b/server.py
index 1e6d03e1e..078b1508b 100644
--- a/server.py
+++ b/server.py
@@ -1,3 +1,5 @@
+import logging
+
 import anyio
 from decouple import config
@@ -18,7 +20,7 @@
 from starlette.middleware.authentication import AuthenticationMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 
-from auth_backend import (
+from auth.auth_backend import (
     SessionAuthBackend,
 )
 from daras_ai_v2 import settings
@@ -67,10 +69,14 @@ async def health():
 
 
 @app.add_middleware
 def request_time_middleware(app):
+    logger = logging.getLogger("uvicorn.time")
+
     async def middleware(scope, receive, send):
         start_time = time()
         await app(scope, receive, send)
         response_time = (time() - start_time) * 1000
-        print(f"{scope.get('method')} {scope.get('path')} - {response_time:.3f} ms")
+        logger.info(
+            f"{scope.get('method')} {scope.get('path')} - {response_time:.3f} ms"
+        )
 
     return middleware
diff --git a/tests/test_apis.py b/tests/test_apis.py
index d3e1d8b7d..6412e9eda 100644
--- a/tests/test_apis.py
+++ b/tests/test_apis.py
@@ -1,10 +1,10 @@
 import typing
 
 import pytest
-from fastapi.testclient import TestClient
+from starlette.testclient import TestClient
 
-from auth_backend import force_authentication
-from daras_ai_v2 import db
+from auth.auth_backend import force_authentication
+from bots.models import SavedRun, Workflow
 from daras_ai_v2.all_pages import all_test_pages
 from daras_ai_v2.base import (
     BasePage,
@@ -12,21 +12,77 @@
 )
 from server import app
 
+MAX_WORKERS = 20
+
 client = TestClient(app)
 
 
-@pytest.mark.parametrize("page_cls", all_test_pages)
-def test_apis_basic(page_cls: typing.Type[BasePage]):
-    page = page_cls()
-    state = db.get_or_create_doc(db.get_doc_ref(page.slug_versions[0])).to_dict()
+@pytest.mark.django_db
+def test_apis_sync(mock_gui_runner, force_authentication, threadpool_subtest):
+    for page_cls in all_test_pages:
+        threadpool_subtest(_test_api_sync, page_cls)
+
+
+def _test_api_sync(page_cls: typing.Type[BasePage]):
+    state = page_cls.recipe_doc_sr().state
+    r = client.post(
+        f"/v2/{page_cls.slug_versions[0]}/",
+        json=get_example_request_body(page_cls.RequestModel, state),
+        headers={"Authorization": f"Token None"},
+        allow_redirects=False,
+    )
+    assert r.status_code == 200, r.text
+
+
+@pytest.mark.django_db
+def test_apis_async(mock_gui_runner, force_authentication, threadpool_subtest):
+    for page_cls in all_test_pages:
+        threadpool_subtest(_test_api_async, page_cls)
+
+
+def _test_api_async(page_cls: typing.Type[BasePage]):
+    state = page_cls.recipe_doc_sr().state
+
+    r = client.post(
+        f"/v3/{page_cls.slug_versions[0]}/async/",
+        json=get_example_request_body(page_cls.RequestModel, state),
+        headers={"Authorization": f"Token None"},
+        allow_redirects=False,
+    )
+    assert r.status_code == 202, r.text
+
+    status_url = r.json()["status_url"]
+
+    r = client.get(
+        status_url,
+        headers={"Authorization": f"Token None"},
+        allow_redirects=False,
+    )
+    assert r.status_code == 200, r.text
+
+    data = r.json()
+    assert data.get("status") == "completed", data
+    assert data.get("output") is not None, data
+
+
+@pytest.mark.django_db
+def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest):
+    for page in all_test_pages:
+        for sr in SavedRun.objects.filter(
+            workflow=page.workflow,
+            hidden=False,
+            example_id__isnull=False,
+        ):
+            threadpool_subtest(_test_apis_examples, sr)
 
-    with force_authentication():
-        r = client.post(
-            page.endpoint,
-            json=get_example_request_body(page.RequestModel, state),
-            headers={"Authorization": f"Token None"},
-            allow_redirects=False,
-        )
-        print(r.content)
-        assert r.status_code == 200
+
+def _test_apis_examples(sr: SavedRun):
+    state = sr.state
+    page_cls = Workflow(sr.workflow).page_cls
+    r = client.post(
+        f"/v2/{page_cls.slug_versions[0]}/?example_id={sr.example_id}",
+        json=get_example_request_body(page_cls.RequestModel, state),
+        headers={"Authorization": f"Token None"},
+        allow_redirects=False,
+    )
+    assert r.status_code == 200, r.text
diff --git a/tests/test_checkout.py b/tests/test_checkout.py
index 5ea7c326b..5e822b775 100644
--- a/tests/test_checkout.py
+++ b/tests/test_checkout.py
@@ -1,7 +1,7 @@
 import pytest
 from fastapi.testclient import TestClient
 
-from auth_backend import force_authentication
+from auth.auth_backend import force_authentication
 from routers.billing import available_subscriptions
 from server import app
@@ -9,7 +9,7 @@
 
 @pytest.mark.parametrize("subscription", available_subscriptions.keys())
-def test_create_checkout_session(subscription: str):
+def test_create_checkout_session(subscription: str, transactional_db):
     with force_authentication():
         form_data = {"lookup_key": subscription}
         if subscription == "addon":
diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py
index 009fb836c..02006b92f 100644
--- a/tests/test_public_endpoints.py
+++ b/tests/test_public_endpoints.py
@@ -1,20 +1,20 @@
 import pytest
-import requests
-from furl import furl
 from starlette.routing import Route
 from starlette.testclient import TestClient
 
-from auth_backend import force_authentication
-from daras_ai_v2 import settings
+from bots.models import SavedRun
 from daras_ai_v2.all_pages import all_api_pages
 from daras_ai_v2.tabs_widget import MenuTabs
 from routers import facebook
+from routers.slack import slack_connect_redirect_shortcuts, slack_connect_redirect
 from server import app
 
 client = TestClient(app)
 
 excluded_endpoints = [
     facebook.fb_webhook_verify.__name__,  # gives 403
+    slack_connect_redirect.__name__,
+    slack_connect_redirect_shortcuts.__name__,
     "get_run_status",  # needs query params
 ]
@@ -30,10 +30,10 @@
 ]
 
 
+@pytest.mark.django_db
 @pytest.mark.parametrize("path", route_paths)
 def test_all_get(path):
     r = client.get(path, allow_redirects=False)
-    print(r.content)
     assert r.ok
 
 
@@ -41,14 +41,29 @@
 tabs = list(MenuTabs.paths.values())
 
 
-@pytest.mark.parametrize("tab", tabs)
+@pytest.mark.django_db
 @pytest.mark.parametrize("slug", page_slugs)
+@pytest.mark.parametrize("tab", tabs)
 def test_page_slugs(slug, tab):
-    with force_authentication():
-        r = requests.post(
-            str(furl(settings.API_BASE_URL) / slug / tab),
-            json={},
-        )
-        # r = client.post(os.path.join(slug, tab), json={}, allow_redirects=True)
-        print(r.content)
+    r = client.post(
+        f"/{slug}/{tab}",
+        json={},
+        allow_redirects=True,
+    )
     assert r.status_code == 200
+
+
+@pytest.mark.django_db
+def test_example_slugs(subtests):
+    for page_cls in all_api_pages:
+        for tab in tabs:
+            for example_id in SavedRun.objects.filter(
+                workflow=page_cls.workflow,
+                hidden=False,
+                example_id__isnull=False,
+            ).values_list("example_id", flat=True):
+                slug = page_cls.slug_versions[0]
+                url = f"/{slug}/{tab}?example_id={example_id}"
+                with subtests.test(msg=url):
+                    r = client.post(url, json={}, allow_redirects=True)
+                    assert r.status_code == 200
diff --git a/tests/test_search_refs.py b/tests/test_search_refs.py
index 5374db260..00698bf88 100644
--- a/tests/test_search_refs.py
+++ b/tests/test_search_refs.py
@@ -1,4 +1,6 @@
-from daras_ai_v2.search_ref import parse_refs
+import pytest
+
+from daras_ai_v2.search_ref import parse_refs, generate_footnote_symbol
 
 
 def test_ref_parser():
@@ -126,3 +128,21 @@ def test_ref_parser():
         },
     ),
 ]
+
+
+def test_generate_footnote_symbol():
+    assert generate_footnote_symbol(0) == "*"
+    assert generate_footnote_symbol(1) == "†"
+    assert generate_footnote_symbol(13) == "✡"
+    assert generate_footnote_symbol(14) == "**"
+    assert generate_footnote_symbol(15) == "††"
+    assert generate_footnote_symbol(27) == "✡✡"
+    assert generate_footnote_symbol(28) == "***"
+    assert generate_footnote_symbol(29) == "†††"
+    assert generate_footnote_symbol(41) == "✡✡✡"
+    assert generate_footnote_symbol(70) == "******"
+    assert generate_footnote_symbol(71) == "††††††"
+
+    # testing with non-integer index
+    with pytest.raises(TypeError):
+        generate_footnote_symbol(1.5)
diff --git a/tests/test_slack.py b/tests/test_slack.py
index ebe72ca96..3a5802c8d 100644
--- a/tests/test_slack.py
+++ b/tests/test_slack.py
@@ -33,6 +33,7 @@ def test_slack_get_response_for_msg_id(transactional_db):
         platform_msg_id="response-msg",
         saved_run=SavedRun.objects.create(
             state=VideoBotsPage.ResponseModel(
+                raw_tts_text=["hello, world!"],
                 output_text=["hello, world! [2]"],
                 final_prompt="",
                 output_audio=[],
diff --git a/tests/test_translation.py b/tests/test_translation.py
index 1079036dd..95cf2a539 100644
--- a/tests/test_translation.py
+++ b/tests/test_translation.py
@@ -1,21 +1,19 @@
-import pytest
-
 from daras_ai_v2.asr import run_google_translate
 
 TRANSLATION_TESTS = [
     # hindi romanized
     (
         "Hi Sir Mera khet me mircha ke ped me fal gal Kar gir hai to iske liye ham kon sa dawa de please help me",
-        "Hi sir the fruits of chilli tree in my field have rotted and fallen so what medicine should we give for this please help",
+        "hi sir the fruit of the chilli tree in my field has rotted and fallen so what medicine should we give for this please help",
     ),
     (
         "Mirchi ka ped",
-        "chili tree",
+        "chilli tree",
     ),
     # hindi
     (
         "ान का नर्सरी खेत में रोकने के लिए कितने दिन में तैयार हो जाता है",
-        "in how many days the corn nursery is ready to stop in the field",
+        "in how many days does the seed nursery become ready to be planted in the field?",
     ),
     # telugu
     (
@@ -44,8 +42,12 @@
 ]
 
 
-@pytest.mark.parametrize("text, expected", TRANSLATION_TESTS)
-def test_run_google_translate(text: str, expected: str):
+def test_run_google_translate(threadpool_subtest):
+    for text, expected in TRANSLATION_TESTS:
+        threadpool_subtest(_test_run_google_translate, text, expected)
+
+
+def _test_run_google_translate(text: str, expected: str):
     actual = run_google_translate([text], "en")[0]
     assert (
         actual.replace(".", "").replace(",", "").strip().lower()
diff --git a/url_shortener/tests.py b/url_shortener/tests.py
index 5e028329a..2589434f1 100644
--- a/url_shortener/tests.py
+++ b/url_shortener/tests.py
@@ -9,14 +9,14 @@
 client = TestClient(app)
 
 
-def test_url_shortener():
+def test_url_shortener(transactional_db):
     surl = ShortenedURL.objects.create(url=TEST_URL)
     short_url = surl.shortened_url()
     r = client.get(short_url, allow_redirects=False)
     assert r.is_redirect and r.headers["location"] == TEST_URL
 
 
-def test_url_shortener_max_clicks():
+def test_url_shortener_max_clicks(transactional_db):
     surl = ShortenedURL.objects.create(url=TEST_URL, max_clicks=5)
     short_url = surl.shortened_url()
     for _ in range(5):
@@ -26,14 +26,14 @@ def test_url_shortener_max_clicks():
     assert r.status_code == 410
 
 
-def test_url_shortener_disabled():
+def test_url_shortener_disabled(transactional_db):
     surl = ShortenedURL.objects.create(url=TEST_URL, disabled=True)
     short_url = surl.shortened_url()
     r = client.get(short_url, allow_redirects=False)
     assert r.status_code == 410
 
 
-def test_url_shortener_create_atomic():
+def test_url_shortener_create_atomic(transactional_db):
     def create(_):
         return [
             ShortenedURL.objects.create(url=TEST_URL).shortened_url()
@@ -43,7 +43,7 @@ def create(_):
     assert len(set(flatmap_parallel(create, range(5)))) == 500
 
 
-def test_url_shortener_clicks_decrement_atomic():
+def test_url_shortener_clicks_decrement_atomic(transactional_db):
     surl = ShortenedURL.objects.create(url=TEST_URL)
     short_url = surl.shortened_url()