Skip to content

Commit

Permalink
Merge branch 'master' into elevenlabs
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Oct 3, 2023
2 parents 459e6a5 + dc59890 commit 0357e02
Show file tree
Hide file tree
Showing 30 changed files with 870 additions and 168 deletions.
29 changes: 15 additions & 14 deletions auth_backend.py → auth/auth_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,37 @@
from starlette.concurrency import run_in_threadpool

from app_users.models import AppUser
from daras_ai_v2.crypto import get_random_string, get_random_doc_id
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.db import FIREBASE_SESSION_COOKIE, ANONYMOUS_USER_COOKIE
from gooeysite.bg_db_conn import db_middleware

# quick and dirty way to bypass authentication for testing
_forced_auth_user = []
authlocal = []


@contextmanager
def force_authentication():
user = AppUser.objects.create(
is_anonymous=True, uid=get_random_doc_id(), balance=10**9
authlocal.append(
AppUser.objects.get_or_create(
email="[email protected]",
defaults=dict(is_anonymous=True, uid=get_random_doc_id(), balance=10**9),
)[0]
)
try:
_forced_auth_user.append(user)
yield
yield authlocal[0]
finally:
_forced_auth_user.clear()
authlocal.clear()


class SessionAuthBackend(AuthenticationBackend):
async def authenticate(self, conn):
return await run_in_threadpool(authenticate, conn)
if authlocal:
return AuthCredentials(["authenticated"]), authlocal[0]
return await run_in_threadpool(_authenticate, conn)


@db_middleware
def authenticate(conn):
if _forced_auth_user:
return AuthCredentials(["authenticated"]), _forced_auth_user[0]

def _authenticate(conn):
session_cookie = conn.session.get(FIREBASE_SESSION_COOKIE)
if not session_cookie:
# Session cookie is unavailable. Check if anonymous user is available.
Expand All @@ -51,7 +52,7 @@ def authenticate(conn):
# Session cookie is unavailable. Force user to login.
return AuthCredentials(), None

user = verify_session_cookie(session_cookie)
user = _verify_session_cookie(session_cookie)
if not user:
# Session cookie was invalid
conn.session.pop(FIREBASE_SESSION_COOKIE, None)
Expand All @@ -60,7 +61,7 @@ def authenticate(conn):
return AuthCredentials(["authenticated"]), user


def verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
def _verify_session_cookie(firebase_cookie: str) -> UserRecord | None:
# Verify the session cookie. In this case an additional check is added to detect
# if the user's Firebase session was revoked, user deleted/disabled, etc.
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import threading

from fastapi import Header
from fastapi.exceptions import HTTPException

from app_users.models import AppUser
from auth_backend import _forced_auth_user
from auth.auth_backend import authlocal
from daras_ai_v2 import db
from daras_ai_v2.crypto import PBKDF2PasswordHasher

Expand All @@ -15,9 +17,8 @@ def api_auth_header(
description=f"{auth_keyword} $GOOEY_API_KEY",
),
) -> AppUser:
if _forced_auth_user:
return _forced_auth_user[0]

if authlocal:
return authlocal[0]
return authenticate(authorization)


Expand Down
3 changes: 2 additions & 1 deletion bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Workflow(models.IntegerChoices):
RELATED_QNA_MAKER = (27, "Related QnA Maker")
RELATED_QNA_MAKER_DOC = (28, "Related QnA Maker Doc")
EMBEDDINGS = (29, "Embeddings")
BULK_RUNNER = (30, "Bulk Runner")

@property
def short_slug(self):
Expand Down Expand Up @@ -239,7 +240,7 @@ def submit_api_call(
kwds=dict(
page_cls=Workflow(self.workflow).page_cls,
query_params=dict(
example_id=self.example_id, run_id=self.id, uid=self.uid
example_id=self.example_id, run_id=self.run_id, uid=self.uid
),
user=current_user,
request_body=request_body,
Expand Down
4 changes: 2 additions & 2 deletions bots/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
CHATML_ROLE_ASSISSTANT = "assistant"


def test_add_balance_direct():
def test_add_balance_direct(transactional_db):
pk = AppUser.objects.create(balance=0, is_anonymous=False).pk
amounts = [[random.randint(-100, 10_000) for _ in range(100)] for _ in range(5)]

Expand All @@ -28,7 +28,7 @@ def worker(amts):
assert AppUser.objects.get(pk=pk).balance == sum(map(sum, amounts))


def test_create_bot_integration_conversation_message():
def test_create_bot_integration_conversation_message(transactional_db):
# Create a new BotIntegration with WhatsApp as the platform
bot_integration = BotIntegration.objects.create(
name="My Bot Integration",
Expand Down
147 changes: 138 additions & 9 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,149 @@
import typing
from functools import wraps
from threading import Thread
from unittest.mock import patch

import pytest
from pytest_subtests import subtests

from auth import auth_backend
from celeryapp import app
from daras_ai_v2.base import BasePage

@pytest.fixture

@pytest.fixture(scope="session")
def django_db_setup(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
from django.core.management import call_command

call_command("loaddata", "fixture.json")


@pytest.fixture(
# add this fixture to all tests
autouse=True
)
def enable_db_access_for_all_tests(
# enable transactional db
transactional_db,
@pytest.fixture
def force_authentication():
with auth_backend.force_authentication() as user:
yield user


app.conf.task_always_eager = True


@pytest.fixture
def mock_gui_runner():
with patch("celeryapp.tasks.gui_runner", _mock_gui_runner):
yield


@app.task
def _mock_gui_runner(
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
pass
sr = page_cls.run_doc_sr(run_id, uid)
sr.set(sr.parent.to_dict())
sr.save()


@pytest.fixture
def threadpool_subtest(subtests, max_workers: int = 8):
ts = []

def submit(fn, *args, **kwargs):
msg = "--".join(map(str, args))

@wraps(fn)
def runner(*args, **kwargs):
with subtests.test(msg=msg):
return fn(*args, **kwargs)

ts.append(Thread(target=runner, args=args, kwargs=kwargs))

yield submit

for i in range(0, len(ts), max_workers):
s = slice(i, i + max_workers)
for t in ts[s]:
t.start()
for t in ts[s]:
t.join()


# class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker):
# class _dj_db_wrapper:
# def ensure_connection(self):
# pass
#
#
# @pytest.mark.tryfirst
# def pytest_sessionstart(session):
# # make the session threadsafe
# _pytest.runner.SetupState = ThreadLocalSetupState
#
# # ensure that the fixtures (specifically finalizers) are threadsafe
# _pytest.fixtures.FixtureDef = ThreadLocalFixtureDef
#
# django.test.utils._TestState = threading.local()
#
# # make the environment threadsafe
# os.environ = ThreadLocalEnviron(os.environ)
#
#
# @pytest.hookimpl
# def pytest_runtestloop(session: _pytest.main.Session):
# if session.testsfailed and not session.config.option.continue_on_collection_errors:
# raise session.Interrupted(
# "%d error%s during collection"
# % (session.testsfailed, "s" if session.testsfailed != 1 else "")
# )
#
# if session.config.option.collectonly:
# return True
#
# num_threads = 10
#
# threadpool_items = [
# item
# for item in session.items
# if any("run_in_threadpool" in marker.name for marker in item.iter_markers())
# ]
#
# manager = pytest_django.plugin._blocking_manager
# pytest_django.plugin._blocking_manager = DummyDatabaseBlocker()
# try:
# with manager.unblock():
# for j in range(0, len(threadpool_items), num_threads):
# s = slice(j, j + num_threads)
# futs = []
# for i, item in enumerate(threadpool_items[s]):
# nextitem = (
# threadpool_items[i + 1]
# if i + 1 < len(threadpool_items)
# else None
# )
# p = threading.Thread(
# target=item.config.hook.pytest_runtest_protocol,
# kwargs=dict(item=item, nextitem=nextitem),
# )
# p.start()
# futs.append(p)
# for p in futs:
# p.join()
# if session.shouldfail:
# raise session.Failed(session.shouldfail)
# if session.shouldstop:
# raise session.Interrupted(session.shouldstop)
# finally:
# pytest_django.plugin._blocking_manager = manager
#
# session.items = [
# item
# for item in session.items
# if not any("run_in_threadpool" in marker.name for marker in item.iter_markers())
# ]
# for i, item in enumerate(session.items):
# nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
# item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
# if session.shouldfail:
# raise session.Failed(session.shouldfail)
# if session.shouldstop:
# raise session.Interrupted(session.shouldstop)
# return True
10 changes: 4 additions & 6 deletions daras_ai/image_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from daras_ai_v2 import settings

if False:
from firebase_admin import storage


def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
img_cv2 = bytes_to_cv2_img(img_bytes)
Expand Down Expand Up @@ -54,15 +57,10 @@ def upload_file_from_bytes(
data: bytes,
content_type: str = None,
) -> str:
from firebase_admin import storage

if not content_type:
content_type = mimetypes.guess_type(filename)[0]
content_type = content_type or "application/octet-stream"

filename = safe_filename(filename)
bucket = storage.bucket(settings.GS_BUCKET_NAME)
blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
blob = storage_blob_for(filename)
blob.upload_from_string(data, content_type=content_type)
return blob.public_url

Expand Down
2 changes: 2 additions & 0 deletions daras_ai_v2/all_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from bots.models import Workflow
from daras_ai_v2.base import BasePage
from recipes.BulkRunner import BulkRunnerPage
from recipes.ChyronPlant import ChyronPlantPage
from recipes.CompareLLM import CompareLLMPage
from recipes.CompareText2Img import CompareText2ImgPage
Expand Down Expand Up @@ -71,6 +72,7 @@
ImageSegmentationPage,
CompareUpscalerPage,
DocExtractPage,
BulkRunnerPage,
]

# exposed as API
Expand Down
2 changes: 1 addition & 1 deletion daras_ai_v2/api_examples_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from furl import furl

from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url
from gooey_token_authentication1.token_authentication import auth_keyword
from auth.token_authentication import auth_keyword


def get_filenames(request_body):
Expand Down
Loading

0 comments on commit 0357e02

Please sign in to comment.