diff --git a/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py new file mode 100644 index 000000000..44033b955 --- /dev/null +++ b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.7 on 2024-07-12 19:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0076_alter_workflowmetadata_default_image_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='error_code', + field=models.IntegerField(blank=True, default=None, help_text='The HTTP status code of the error. If this is not set, 500 is assumed.', null=True), + ), + migrations.AddField( + model_name='savedrun', + name='error_type', + field=models.TextField(blank=True, default='', help_text='The exception type'), + ), + migrations.AlterField( + model_name='savedrun', + name='error_msg', + field=models.TextField(blank=True, default='', help_text='The error message. If this is not set, the run is deemed successful.'), + ), + ] diff --git a/conftest.py b/conftest.py index d355eb55e..a38c6a11a 100644 --- a/conftest.py +++ b/conftest.py @@ -51,9 +51,10 @@ def force_authentication(): @pytest.fixture -def mock_gui_runner(): +def mock_celery_tasks(): with ( patch("celeryapp.tasks.runner_task", _mock_runner_task), + patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks), patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe), ): yield @@ -70,6 +71,11 @@ def _mock_runner_task( _mock_realtime_push(channel, sr.to_dict()) +@app.task +def _mock_post_runner_tasks(*args, **kwargs): + pass + + def _mock_realtime_push(channel, value): redis_qs[channel].put(value) diff --git a/tests/test_apis.py b/tests/test_apis.py index bd3a915fb..fa897eb83 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -15,7 +15,7 @@ @pytest.mark.django_db -def test_apis_sync(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_sync, page_cls) @@ -32,7 +32,7 @@ def _test_api_sync(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_async(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_async, page_cls) @@ -65,7 +65,7 @@ def _test_api_async(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_examples(mock_celery_tasks, force_authentication, threadpool_subtest): qs = ( PublishedRun.objects.exclude(is_approved_example=False) .exclude(published_run_id="") diff --git a/tests/test_integrations_api.py b/tests/test_integrations_api.py index c6f11c9d8..398fdd52a 100644 --- a/tests/test_integrations_api.py +++ b/tests/test_integrations_api.py @@ -11,7 +11,7 @@ @pytest.mark.django_db -def test_send_msg_streaming(mock_gui_runner, force_authentication): +def test_send_msg_streaming(mock_celery_tasks, force_authentication): r = client.post( "/v3/integrations/stream/", json={