-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into elevenlabs
- Loading branch information
Showing
30 changed files
with
870 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,36 +6,37 @@ | |
from starlette.concurrency import run_in_threadpool | ||
|
||
from app_users.models import AppUser | ||
from daras_ai_v2.crypto import get_random_string, get_random_doc_id | ||
from daras_ai_v2.crypto import get_random_doc_id | ||
from daras_ai_v2.db import FIREBASE_SESSION_COOKIE, ANONYMOUS_USER_COOKIE | ||
from gooeysite.bg_db_conn import db_middleware | ||
|
||
# quick and dirty way to bypass authentication for testing | ||
_forced_auth_user = [] | ||
authlocal = [] | ||
|
||
|
||
@contextmanager | ||
def force_authentication(): | ||
user = AppUser.objects.create( | ||
is_anonymous=True, uid=get_random_doc_id(), balance=10**9 | ||
authlocal.append( | ||
AppUser.objects.get_or_create( | ||
email="[email protected]", | ||
defaults=dict(is_anonymous=True, uid=get_random_doc_id(), balance=10**9), | ||
)[0] | ||
) | ||
try: | ||
_forced_auth_user.append(user) | ||
yield | ||
yield authlocal[0] | ||
finally: | ||
_forced_auth_user.clear() | ||
authlocal.clear() | ||
|
||
|
||
class SessionAuthBackend(AuthenticationBackend): | ||
async def authenticate(self, conn): | ||
return await run_in_threadpool(authenticate, conn) | ||
if authlocal: | ||
return AuthCredentials(["authenticated"]), authlocal[0] | ||
return await run_in_threadpool(_authenticate, conn) | ||
|
||
|
||
@db_middleware | ||
def authenticate(conn): | ||
if _forced_auth_user: | ||
return AuthCredentials(["authenticated"]), _forced_auth_user[0] | ||
|
||
def _authenticate(conn): | ||
session_cookie = conn.session.get(FIREBASE_SESSION_COOKIE) | ||
if not session_cookie: | ||
# Session cookie is unavailable. Check if anonymous user is available. | ||
|
@@ -51,7 +52,7 @@ def authenticate(conn): | |
# Session cookie is unavailable. Force user to login. | ||
return AuthCredentials(), None | ||
|
||
user = verify_session_cookie(session_cookie) | ||
user = _verify_session_cookie(session_cookie) | ||
if not user: | ||
# Session cookie was invalid | ||
conn.session.pop(FIREBASE_SESSION_COOKIE, None) | ||
|
@@ -60,7 +61,7 @@ def authenticate(conn): | |
return AuthCredentials(["authenticated"]), user | ||
|
||
|
||
def verify_session_cookie(firebase_cookie: str) -> UserRecord | None: | ||
def _verify_session_cookie(firebase_cookie: str) -> UserRecord | None: | ||
# Verify the session cookie. In this case an additional check is added to detect | ||
# if the user's Firebase session was revoked, user deleted/disabled, etc. | ||
try: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,149 @@ | ||
import typing | ||
from functools import wraps | ||
from threading import Thread | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from pytest_subtests import subtests | ||
|
||
from auth import auth_backend | ||
from celeryapp import app | ||
from daras_ai_v2.base import BasePage | ||
|
||
@pytest.fixture | ||
|
||
@pytest.fixture(scope="session") | ||
def django_db_setup(django_db_setup, django_db_blocker): | ||
with django_db_blocker.unblock(): | ||
from django.core.management import call_command | ||
|
||
call_command("loaddata", "fixture.json") | ||
|
||
|
||
@pytest.fixture( | ||
# add this fixture to all tests | ||
autouse=True | ||
) | ||
def enable_db_access_for_all_tests( | ||
# enable transactional db | ||
transactional_db, | ||
@pytest.fixture | ||
def force_authentication(): | ||
with auth_backend.force_authentication() as user: | ||
yield user | ||
|
||
|
||
app.conf.task_always_eager = True | ||
|
||
|
||
@pytest.fixture | ||
def mock_gui_runner(): | ||
with patch("celeryapp.tasks.gui_runner", _mock_gui_runner): | ||
yield | ||
|
||
|
||
@app.task | ||
def _mock_gui_runner( | ||
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs | ||
): | ||
pass | ||
sr = page_cls.run_doc_sr(run_id, uid) | ||
sr.set(sr.parent.to_dict()) | ||
sr.save() | ||
|
||
|
||
@pytest.fixture | ||
def threadpool_subtest(subtests, max_workers: int = 8): | ||
ts = [] | ||
|
||
def submit(fn, *args, **kwargs): | ||
msg = "--".join(map(str, args)) | ||
|
||
@wraps(fn) | ||
def runner(*args, **kwargs): | ||
with subtests.test(msg=msg): | ||
return fn(*args, **kwargs) | ||
|
||
ts.append(Thread(target=runner, args=args, kwargs=kwargs)) | ||
|
||
yield submit | ||
|
||
for i in range(0, len(ts), max_workers): | ||
s = slice(i, i + max_workers) | ||
for t in ts[s]: | ||
t.start() | ||
for t in ts[s]: | ||
t.join() | ||
|
||
|
||
# class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker): | ||
# class _dj_db_wrapper: | ||
# def ensure_connection(self): | ||
# pass | ||
# | ||
# | ||
# @pytest.mark.tryfirst | ||
# def pytest_sessionstart(session): | ||
# # make the session threadsafe | ||
# _pytest.runner.SetupState = ThreadLocalSetupState | ||
# | ||
# # ensure that the fixtures (specifically finalizers) are threadsafe | ||
# _pytest.fixtures.FixtureDef = ThreadLocalFixtureDef | ||
# | ||
# django.test.utils._TestState = threading.local() | ||
# | ||
# # make the environment threadsafe | ||
# os.environ = ThreadLocalEnviron(os.environ) | ||
# | ||
# | ||
# @pytest.hookimpl | ||
# def pytest_runtestloop(session: _pytest.main.Session): | ||
# if session.testsfailed and not session.config.option.continue_on_collection_errors: | ||
# raise session.Interrupted( | ||
# "%d error%s during collection" | ||
# % (session.testsfailed, "s" if session.testsfailed != 1 else "") | ||
# ) | ||
# | ||
# if session.config.option.collectonly: | ||
# return True | ||
# | ||
# num_threads = 10 | ||
# | ||
# threadpool_items = [ | ||
# item | ||
# for item in session.items | ||
# if any("run_in_threadpool" in marker.name for marker in item.iter_markers()) | ||
# ] | ||
# | ||
# manager = pytest_django.plugin._blocking_manager | ||
# pytest_django.plugin._blocking_manager = DummyDatabaseBlocker() | ||
# try: | ||
# with manager.unblock(): | ||
# for j in range(0, len(threadpool_items), num_threads): | ||
# s = slice(j, j + num_threads) | ||
# futs = [] | ||
# for i, item in enumerate(threadpool_items[s]): | ||
# nextitem = ( | ||
# threadpool_items[i + 1] | ||
# if i + 1 < len(threadpool_items) | ||
# else None | ||
# ) | ||
# p = threading.Thread( | ||
# target=item.config.hook.pytest_runtest_protocol, | ||
# kwargs=dict(item=item, nextitem=nextitem), | ||
# ) | ||
# p.start() | ||
# futs.append(p) | ||
# for p in futs: | ||
# p.join() | ||
# if session.shouldfail: | ||
# raise session.Failed(session.shouldfail) | ||
# if session.shouldstop: | ||
# raise session.Interrupted(session.shouldstop) | ||
# finally: | ||
# pytest_django.plugin._blocking_manager = manager | ||
# | ||
# session.items = [ | ||
# item | ||
# for item in session.items | ||
# if not any("run_in_threadpool" in marker.name for marker in item.iter_markers()) | ||
# ] | ||
# for i, item in enumerate(session.items): | ||
# nextitem = session.items[i + 1] if i + 1 < len(session.items) else None | ||
# item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem) | ||
# if session.shouldfail: | ||
# raise session.Failed(session.shouldfail) | ||
# if session.shouldstop: | ||
# raise session.Interrupted(session.shouldstop) | ||
# return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.