From 45f226adaea38d09987992311459e95a10f73750 Mon Sep 17 00:00:00 2001 From: Helen Lin Date: Mon, 25 Nov 2024 19:37:18 -0800 Subject: [PATCH] test: update unit tests --- src/aind_data_transfer_service/server.py | 1 - tests/test_server.py | 135 +++++++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 3f3783b..22cbf96 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -72,7 +72,6 @@ # AIND_AIRFLOW_SERVICE_PASSWORD # AIND_AIRFLOW_SERVICE_USER -# NOTE: add cors to test metadata-entry-app middleware = [ Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) ] diff --git a/tests/test_server.py b/tests/test_server.py index 69e18c4..3cd0879 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1705,6 +1705,141 @@ def test_submit_v1_jobs_200_basic_serialization( ) self.assertEqual(200, submit_job_response.status_code) + @patch("pydantic.BaseModel.model_json_schema") + def test_get_json_schema_for_model( + self, + mock_model_json_schema: MagicMock, + ): + """Tests that the json schema for a model is returned.""" + mock_model_json_schema.return_value = {"foo": "bar"} + models = [ + "BasicUploadJobConfigsForm", + "SubmitJobRequestForm", + "BasicUploadJobConfigs", + "SubmitJobRequest", + ] + for model in models: + with TestClient(app) as client: + response = client.get(f"/api/v1/models/{model}/schema") + self.assertEqual(200, response.status_code) + self.assertEqual({"foo": "bar"}, response.json()) + self.assertEqual(4, mock_model_json_schema.call_count) + + @patch("pydantic.BaseModel.model_json_schema") + def test_get_json_schema_for_model_error( + self, + mock_model_json_schema: MagicMock, + ): + """Tests that 404 error is returned if model is not found.""" + model = "Test" + with TestClient(app) as client: + response = client.get(f"/api/v1/models/{model}/schema") + self.assertEqual(404, response.status_code) + self.assertEqual( + {"message": f"Schema not found for {model}"}, response.json() + ) + mock_model_json_schema.assert_not_called() + + def test_validate_json_for_model(self): + """Tests validate_json_for_model when json is valid.""" + ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165") + + s3_bucket = "private" + subject_id = "690165" + acq_datetime = datetime(2024, 2, 19, 11, 25, 17) + platform = Platform.ECEPHYS + + ephys_config = ModalityConfigs( + modality=Modality.ECEPHYS, + source=ephys_source_dir, + ) + project_name = "Ephys Platform" + + upload_job_configs = BasicUploadJobConfigs( + project_name=project_name, + s3_bucket=s3_bucket, + platform=platform, + subject_id=subject_id, + acq_datetime=acq_datetime, + modalities=[ephys_config], + ) + # validate BasicUploadJobConfigs + post_request_content = json.loads(upload_job_configs.model_dump_json()) + with TestClient(app) as client: + response = client.post( + "/api/v1/models/BasicUploadJobConfigs/validate", + json=post_request_content, + ) + response_json = response.json() + self.assertEqual(200, response.status_code) + self.assertEqual("Valid model", response_json["message"]) + self.assertEqual( + post_request_content, response_json["data"]["model_json"] + ) + + # validate SubmitJobRequest + submit_request = SubmitJobRequest(upload_jobs=[upload_job_configs]) + post_request_content = json.loads(submit_request.model_dump_json()) + + with TestClient(app) as client: + response = client.post( + "/api/v1/models/SubmitJobRequest/validate", + json=post_request_content, + ) + response_json = response.json() + self.assertEqual(200, response.status_code) + self.assertEqual("Valid model", response_json["message"]) + self.assertEqual( + post_request_content, response_json["data"]["model_json"] + ) + + def test_validate_json_for_model_not_found(self): + """Tests validate_json_for_model when model is not found.""" + model = "Test" + with TestClient(app) as client: + response = client.post( + f"/api/v1/models/{model}/validate", json={"foo": "bar"} + ) + response_json = response.json() + self.assertEqual(404, response.status_code) + self.assertEqual( + {"message": f"Model not found for {model}"}, response_json + ) + + def test_validate_json_for_model_invalid(self): + """Tests validate_json_for_model when json is invalid.""" + with TestClient(app) as client: + response = client.post( + "/api/v1/models/SubmitJobRequest/validate", + json={"foo": "bar"}, + ) + response_json = response.json() + self.assertEqual(406, response.status_code) + self.assertEqual( + "There were validation errors", response_json["message"] + ) + self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"]) + + @patch("pydantic.BaseModel.model_validate_json") + def test_validate_json_for_model_error( + self, mock_model_validate_json: MagicMock + ): + """Tests validate_json_for_model when there is an unknown error.""" + mock_model_validate_json.side_effect = Exception("Unknown error") + with TestClient(app) as client: + response = client.post( + "/api/v1/models/SubmitJobRequest/validate", + json={"foo": "bar"}, + ) + response_json = response.json() + self.assertEqual(500, response.status_code) + self.assertEqual( + "There was an internal server error", response_json["message"] + ) + self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"]) + self.assertEqual("('Unknown error',)", response_json["data"]["errors"]) + mock_model_validate_json.assert_called_once() + if __name__ == "__main__": unittest.main()