diff --git a/tests/test_service.py b/tests/test_service.py index 1214f76..2c1c37a 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -5,21 +5,25 @@ import pytest import requests from fastapi import HTTPException +import pendulum from proxy.exception import PrefectException from proxy.schemas import ( - AirbyteConnectionCreate, AirbyteServerCreate, DbtCoreCreate, DbtProfileCreate, - PrefectShellSetup, - DeploymentCreate, + DbtProfileUpdate, + DbtCliProfileBlockUpdate, + DeploymentCreate2, DeploymentUpdate, + DeploymentUpdate2, PrefectSecretBlockCreate, ) from proxy.service import ( _create_dbt_cli_profile, get_dbt_cli_profile, + update_dbt_cli_profile, + get_airbyte_server_block, create_airbyte_server_block, create_dbt_core_block, delete_airbyte_connection_block, @@ -44,12 +48,16 @@ update_postgres_credentials, update_bigquery_credentials, update_target_configs_schema, + post_deployment_v1, put_deployment, + put_deployment_v1, get_deployment, CronSchedule, post_deployment_flow_run, create_secret_block, cancel_flow_run, + get_flow_run_logs_v2, + retry_flow_run, ) @@ -304,6 +312,23 @@ def dict(self): return {"_block_document_id": "expected_block_id"} +@pytest.mark.asyncio +async def test_get_airbyte_server_block_paramcheck(): + blockname = "test_blockname" + with pytest.raises(TypeError) as excinfo: + await get_airbyte_server_block(1) + assert str(excinfo.value) == "blockname must be a string" + + +@pytest.mark.asyncio +async def test_get_airbyte_server_block(): + blockname = "test_blockname" + with patch("proxy.service.AirbyteServer.load", new_callable=AsyncMock) as mock_load: + mock_load.return_value = "expected_block_id" + result = await get_airbyte_server_block(blockname) + assert result == "expected_block_id" + + @pytest.mark.asyncio @patch("proxy.service.AirbyteServer", new=MockAirbyteServer) async def test_create_airbyte_server_block(): @@ -590,6 +615,116 @@ async def test_create_dbt_cli_profile_raises(mock_load): assert str(excinfo.value.detail) == "No dbt cli profile block named test_block_name" +@pytest.mark.asyncio +@patch("proxy.service.DbtCliProfile.load", new_callable=AsyncMock) +async def test_update_dbt_cli_profile(mock_load): + """tests update_dbt_cli_profile""" + mock_load.side_effect = ValueError("error") + with pytest.raises(PrefectException) as excinfo: + payload = DbtCliProfileBlockUpdate( + cli_profile_block_name="dne", wtype=None, profile=None, credentials=None + ) + await update_dbt_cli_profile(payload) + assert str(excinfo.value) == "no dbt cli profile block named dne" + + +@pytest.mark.asyncio +@patch("proxy.service.DbtCliProfile.load", new_callable=AsyncMock) +async def test_update_dbt_cli_profile_postgres(mock_load: AsyncMock): + """tests update_dbt_cli_profile""" + payload = DbtCliProfileBlockUpdate( + cli_profile_block_name="block-name", + profile=DbtProfileUpdate( + target_configs_schema="new_schema", name="profile-name" + ), + credentials={ + "host": "new_host", + "port": "new_port", + "database": "new_database", + "username": "new_username", + "password": "new_password", + }, + wtype="postgres", + ) + mock_load.return_value = Mock( + target_configs=Mock(schema="old-schema"), + save=AsyncMock(), + dict=Mock(return_value={"_block_document_id": "_block_document_id"}), + ) + block, block_id, block_name = await update_dbt_cli_profile(payload) + assert block_name == "block-name" + assert block_id == "_block_document_id" + assert block.target_configs.schema == "new_schema" + assert block.target == "new_schema" + assert block.name == "profile-name" + assert block.target_configs.extras["host"] == "new_host" + assert block.target_configs.extras["port"] == "new_port" + assert block.target_configs.extras["database"] == "new_database" + assert block.target_configs.extras["username"] == "new_username" + assert block.target_configs.extras["password"] == "new_password" + + +@pytest.mark.asyncio +@patch("proxy.service.DbtCliProfile.load", new_callable=AsyncMock) +async def test_update_dbt_cli_profile_postgres_override_target(mock_load: AsyncMock): + """tests update_dbt_cli_profile""" + payload = DbtCliProfileBlockUpdate( + cli_profile_block_name="block-name", + profile=DbtProfileUpdate( + target_configs_schema="new_schema", name="profile-name", target="override" + ), + credentials={ + "host": "new_host", + "port": "new_port", + "database": "new_database", + "username": "new_username", + "password": "new_password", + }, + wtype="postgres", + ) + mock_load.return_value = Mock( + target_configs=Mock(schema="old-schema"), + save=AsyncMock(), + dict=Mock(return_value={"_block_document_id": "_block_document_id"}), + ) + block, block_id, block_name = await update_dbt_cli_profile(payload) + assert block_name == "block-name" + assert block_id == "_block_document_id" + assert block.target == "override" + + +@pytest.mark.asyncio +@patch("proxy.service.DbtCliProfile.load", new_callable=AsyncMock) +@patch("proxy.service.GcpCredentials", Mock(return_value={})) +async def test_update_dbt_cli_profile_bigquery(mock_load: AsyncMock): + """tests update_dbt_cli_profile""" + service_account_info = { + "token_uri": "token_uri", + "client_email": "client_email", + "private_key": "private key", + } + payload = DbtCliProfileBlockUpdate( + cli_profile_block_name="block-name", + profile=DbtProfileUpdate( + target_configs_schema="new_schema", name="profile-name" + ), + credentials=service_account_info, + wtype="bigquery", + bqlocation="bq-location", + ) + mock_load.return_value = Mock( + target_configs=Mock(schema="old-schema"), + save=AsyncMock(), + dict=Mock(return_value={"_block_document_id": "_block_document_id"}), + ) + block, block_id, block_name = await update_dbt_cli_profile(payload) + assert block_name == "block-name" + assert block_id == "_block_document_id" + assert block.target_configs.extras == {"location": "bq-location"} + assert block.target == "new_schema" + assert block.name == "profile-name" + + @pytest.mark.asyncio @patch("proxy.service.DbtCliProfile.save", new_callable=AsyncMock) @patch("proxy.service.DbtCoreOperation.__init__", return_value=None) @@ -877,6 +1012,73 @@ async def test_update_target_configs_schema(mock_load): assert dbt_coreop_block.commands[0] == "dbt run --target newtarget" +async def test_post_deployment_bad_param(): + with pytest.raises(TypeError) as excinfo: + await post_deployment_v1("deployment-id") + assert str(excinfo.value) == "payload must be a DeploymentCreate" + + +@pytest.mark.asyncio +@patch("proxy.service.deployment_schedule_flow_v4", new_callable=Mock) +@patch("proxy.service.Deployment.build_from_flow", new_callable=AsyncMock) +async def test_post_deployment_1( + mock_build_from_flow, mock_deployment_schedule_flow_v4 +): + payload = DeploymentCreate2( + work_queue_name="queue-name", + work_pool_name="pool-name", + flow_name="flow-name", + flow_id="flow-id", + deployment_name="deployment-name", + org_slug="org-slug", + deployment_params={"param1": "value1"}, + cron="* * * * *", + ) + mock_deployment_schedule_flow_v4.with_options = Mock(return_value="flow") + mock_deployment = Mock( + name="deployment-name", + parameters={"param1": "value1"}, + schedule=None, + apply=AsyncMock(return_value="deployment-id"), + ) + mock_build_from_flow.return_value = mock_deployment + deployment = await post_deployment_v1(payload) + mock_build_from_flow.assert_called_once_with( + flow="flow", + name=payload.deployment_name, + work_queue_name="queue-name", + work_pool_name="pool-name", + tags=[payload.org_slug], + is_schedule_active=True, + ) + mock_deployment_schedule_flow_v4.with_options.assert_called_with( + name=payload.flow_name + ) + assert deployment["id"] == "deployment-id" + # assert retval["name"] == "deployment-name" + assert deployment["params"] == mock_deployment.parameters + + +@patch("proxy.service.prefect_patch") +def test_put_deployment_v1(mock_prefect_patch): + payload = DeploymentUpdate2( + deployment_params={"param1": "value1"}, + cron="* * * * *", + work_pool_name="pool-name", + work_queue_name="queue-name", + ) + put_deployment_v1("deployment-id", payload) + mock_prefect_patch.assert_called_once_with( + "deployments/deployment-id", + { + "schedule": CronSchedule(cron="* * * * *").dict(), + "parameters": {"param1": "value1"}, + "work_pool_name": "pool-name", + "work_queue_name": "queue-name", + }, + ) + + def test_put_deployment_bad_param(): payload = 123 with pytest.raises(TypeError) as excinfo: @@ -935,7 +1137,9 @@ def test_get_flow_runs_by_deployment_id_prefect_post(): "deployments": {"id": {"any_": [deployment_id]}}, "flow_runs": { "operator": "and_", - "state": {"type": {"any_": ["COMPLETED", "FAILED"]}}, + "state": { + "type": {"any_": ["COMPLETED", "FAILED", "CRASHED", "CANCELLED"]} + }, }, "limit": limit, } @@ -1137,6 +1341,104 @@ def test_get_flow_run_logs_prefect_post(): prefect_post_mock.assert_called_with("logs/filter", query) +@patch("proxy.service.prefect_get") +@patch("proxy.service.prefect_post") +@patch("proxy.service.traverse_flow_run_graph_v2") +def test_get_flow_run_logs_v2_flow_run( + mock_traverse_flow_run_graph_v2: Mock, + mock_prefect_post: Mock, + mock_prefect_get: Mock, +): + mock_traverse_flow_run_graph_v2.return_value = [ + { + "kind": "flow-run", + "id": "run-id", + "label": "run-label", + "start_time": "start-time", + "end_time": "end-time", + } + ] + mock_prefect_get.return_value = { + "state_name": "state-name", + "state_type": "state-type", + } + mock_prefect_post.return_value = [] + retval = get_flow_run_logs_v2("flow_run_id") + mock_prefect_get.assert_called_once_with("flow_runs/run-id") + mock_prefect_post.assert_called_once_with( + "logs/filter", + { + "logs": { + "operator": "or_", + "flow_run_id": {"any_": ["run-id"]}, + "task_run_id": {"any_": []}, + }, + "sort": "TIMESTAMP_ASC", + }, + ) + assert retval == [ + { + "id": "run-id", + "kind": "flow-run", + "label": "run-label", + "state_type": "state-type", + "state_name": "state-name", + "start_time": "start-time", + "end_time": "end-time", + "logs": [], + } + ] + + +@patch("proxy.service.prefect_get") +@patch("proxy.service.prefect_post") +@patch("proxy.service.traverse_flow_run_graph_v2") +def test_get_flow_run_logs_v2_task_run( + mock_traverse_flow_run_graph_v2: Mock, + mock_prefect_post: Mock, + mock_prefect_get: Mock, +): + mock_traverse_flow_run_graph_v2.return_value = [ + { + "kind": "task-run", + "id": "run-id", + "label": "run-label", + "start_time": "start-time", + "end_time": "end-time", + } + ] + mock_prefect_get.return_value = { + "state_name": "state-name", + "state_type": "state-type", + } + mock_prefect_post.return_value = [] + retval = get_flow_run_logs_v2("flow_run_id") + mock_prefect_get.assert_called_once_with("task_runs/run-id") + mock_prefect_post.assert_called_once_with( + "logs/filter", + { + "logs": { + "operator": "or_", + "flow_run_id": {"any_": []}, + "task_run_id": {"any_": ["run-id"]}, + }, + "sort": "TIMESTAMP_ASC", + }, + ) + assert retval == [ + { + "id": "run-id", + "kind": "task-run", + "label": "run-label", + "state_type": "state-type", + "state_name": "state-name", + "start_time": "start-time", + "end_time": "end-time", + "logs": [], + } + ] + + def test_get_flow_runs_by_name_type_error(): with pytest.raises(TypeError): get_flow_runs_by_name(123) @@ -1189,14 +1491,15 @@ def test_set_deployment_schedule_prefect_post(): @patch("proxy.service.prefect_get") -@patch("proxy.service.get_final_state_for_flow_run") -def test_get_flow_run_success(mock_get_final_state: Mock, mock_get: Mock): - mock_get.return_value = {"id": "12345", "state": {"type": "COMPLETED"}} - mock_get_final_state.return_value = "COMPLETED" +@patch("proxy.service.update_flow_run_final_state") +def test_get_flow_run_success(mock_update_flow_run_final_state: Mock, mock_get: Mock): + mock_get.return_value = {"id": "flow-run-id", "state": {"type": "COMPLETED"}} + mock_update_flow_run_final_state.return_value = {"id": "flow-run-id", "state": {"type": "COMPLETED"}, "state_name": "COMPLETED",} response = get_flow_run("flow-run-id") mock_get.assert_called_once_with("flow_runs/flow-run-id") + mock_update_flow_run_final_state.assert_called_once_with({"id": "flow-run-id", "state": {"type": "COMPLETED"}}) assert response == { - "id": "12345", + "id": "flow-run-id", "state": {"type": "COMPLETED"}, "state_name": "COMPLETED", } @@ -1235,3 +1538,33 @@ async def test_cancel_flow_run_success(): flow_run_id = "valid_flow_run_id" result = await cancel_flow_run(flow_run_id) assert result is None + + +def test_retry_flow_run_bad_param(): + with pytest.raises(TypeError) as excinfo: + retry_flow_run(123) + assert str(excinfo.value) == "flow_run_id must be a string" + + +@patch("proxy.service.prefect_post") +@patch("proxy.service.pendulum") +def test_retry_flow_run(mock_pendulum: Mock, mock_prefect_post: Mock): + mock_pendulum.now.return_value = pendulum.time(0, 0, 0) + mock_pendulum.duration.return_value = pendulum.duration(minutes=5) + retry_flow_run("flow-run-id") + mock_prefect_post.assert_called_once_with( + "flow_runs/flow-run-id/set_state", + { + "force": True, + "state": { + "name": "AwaitingRetry", + "message": "Retry via prefect proxy", + "type": "SCHEDULED", + "state_details": { + "scheduled_time": str( + pendulum.time(0, 0, 0) + pendulum.duration(minutes=5) + ) + }, # using pendulum because prefect also uses it + }, + }, + )