diff --git a/README.md b/README.md index 2b475ed24..43d69a8e4 100644 --- a/README.md +++ b/README.md @@ -88,13 +88,28 @@ export MAGICK_HOME=/opt/homebrew Clone [gooey-gui](https://github.com/GooeyAI/gooey-gui) repo, in the same directory as `gooey-server` and follow the setup steps. -## Running Tests +### Run Tests ``` +ulimit -n unlimited # Increase the number of open files allowed ./scripts/run-tests.sh ``` -(If you run into issues with the number of open files, you can remove the limit with `ulimit -n unlimited`) +### Initialize databse + + +```bash +# reset the database +./manage.py reset_db -c +# create the database +./manage.py sqlcreate | psql postgres +# run migrations +./manage.py migrate +# load the fixture (donwloaded by ./scripts/run-tests.sh) +./manage.py loaddata fixture.json +# create a superuser to access admin +./manage.py createsuperuser +``` ## Run diff --git a/scripts/create_fixture.py b/scripts/create_fixture.py index 61a317d7a..0e48a6eef 100644 --- a/scripts/create_fixture.py +++ b/scripts/create_fixture.py @@ -1,14 +1,25 @@ import sys from django.core import serializers +from django.db.models import NOT_PROVIDED from app_users.models import AppUser from bots.models import BotIntegration, PublishedRun, PublishedRunVisibility +USER_FIELDS = { + "id", + "is_anonymous", + "uid", + "display_name", + "balance", +} -def run(): + +def run(*args): with open("fixture.json", "w") as f: - objs = list(filter(None, get_objects())) + objs = { + f"{type(obj)}:{obj.pk}": obj for obj in filter(None, get_objects(*args)) + }.values() print(f"Exporting {len(objs)} objects") serializers.serialize( "json", @@ -20,7 +31,7 @@ def run(): ) -def get_objects(): +def get_objects(*args): for pr in PublishedRun.objects.filter( is_approved_example=True, visibility=PublishedRunVisibility.PUBLIC, @@ -28,37 +39,49 @@ def get_objects(): if pr.saved_run_id: yield export(pr.saved_run) if pr.created_by_id: - yield export(pr.created_by) + yield pr.created_by.handle + yield export(pr.created_by, only_include=USER_FIELDS) if pr.last_edited_by_id: - yield export(pr.last_edited_by) - yield export(pr, ["saved_run", "created_by", "last_edited_by"]) + yield pr.last_edited_by.handle + yield export(pr.last_edited_by, only_include=USER_FIELDS) + yield export(pr, include_fks={"saved_run", "created_by", "last_edited_by"}) for version in pr.versions.all(): yield export(version.saved_run) - yield export(version, ["published_run", "saved_run"]) + yield export(version, include_fks={"saved_run", "published_run"}) + if "bots" not in args: + return for obj in BotIntegration.objects.all(): user = AppUser.objects.get(uid=obj.billing_account_uid) yield user.handle - yield user + yield export(user, only_include=USER_FIELDS) if obj.saved_run_id: yield export(obj.saved_run) if obj.published_run_id: yield export(obj.published_run.saved_run) - yield export(obj.published_run, ["saved_run"]) + yield export(obj.published_run, include_fks={"saved_run"}) - yield export(obj, ["saved_run", "published_run"]) + yield export(obj, include_fks={"saved_run", "published_run"}) -def export(obj, exclude=()): +def export(obj, *, include_fks=(), only_include=None): for field in obj._meta.get_fields(): - if field.name in exclude: - continue - if field.is_relation: + if field.is_relation and field.name not in include_fks: try: setattr(obj, field.name, None) except TypeError: pass + elif only_include and field.name not in only_include: + if field.default == NOT_PROVIDED: + default = None + else: + default = field.default + try: + default = default() + except TypeError: + pass + setattr(obj, field.name, default) return obj diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh index aeb462d95..9c148abac 100755 --- a/scripts/run-tests.sh +++ b/scripts/run-tests.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash echo "==> Downloading fixture.json..." -wget -N -nv https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/8972b298-1206-11ef-aac6-02420a00010c/fixture.json +wget -N -nv https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/4f614770-446c-11ef-b36e-02420a000176/fixture.json echo "==> Linting with black..." black --check --diff . diff --git a/tests/test_integrations_api.py b/tests/test_integrations_api.py index 398fdd52a..e0e8caacf 100644 --- a/tests/test_integrations_api.py +++ b/tests/test_integrations_api.py @@ -1,21 +1,27 @@ import json -import pytest from furl import furl from starlette.testclient import TestClient -from bots.models import BotIntegration +from bots.models import BotIntegration, Workflow, Platform, SavedRun from server import app client = TestClient(app) -@pytest.mark.django_db -def test_send_msg_streaming(mock_celery_tasks, force_authentication): +def test_send_msg_streaming(transactional_db, mock_celery_tasks, force_authentication): + bi = BotIntegration.objects.create( + platform=Platform.WEB, + billing_account_uid=force_authentication.uid, + saved_run=SavedRun.objects.create( + workflow=Workflow.VIDEO_BOTS, + uid=force_authentication.uid, + ), + ) r = client.post( "/v3/integrations/stream/", json={ - "integration_id": BotIntegration.objects.first().api_integration_id(), + "integration_id": bi.api_integration_id(), "input_text": "hello, world", }, allow_redirects=False,