diff --git a/docs/source/UserGuide.rst b/docs/source/UserGuide.rst index be4bc7b..f9054ec 100644 --- a/docs/source/UserGuide.rst +++ b/docs/source/UserGuide.rst @@ -151,6 +151,10 @@ endpoint: ) post_request_content = json.loads(submit_request.model_dump_json(exclude_none=True)) + # Optionally validate the submit_request before submitting + validate_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/validate_json", json=post_request_content) + print(validate_job_response.status_code) + print(validate_job_response.json()) # Uncomment the following to submit the request # submit_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/submit_jobs", json=post_request_content) # print(submit_job_response.status_code) diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 017e9ff..bed3895 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -9,6 +9,9 @@ from typing import Optional import requests +from aind_data_transfer_models import ( + __version__ as aind_data_transfer_models_version, +) from aind_data_transfer_models.core import SubmitJobRequest from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse @@ -173,6 +176,58 @@ async def validate_csv_legacy(request: Request): ) +async def validate_json(request: Request): + """Validate raw json against aind-data-transfer-models. Returns validated + json or errors if request is invalid.""" + logger.info("Received request to validate json") + content = await request.json() + try: + validated_model = SubmitJobRequest.model_validate_json( + json.dumps(content) + ) + validated_content = json.loads( + validated_model.model_dump_json(warnings=False, exclude_none=True) + ) + logger.info("Valid model detected") + return JSONResponse( + status_code=200, + content={ + "message": "Valid model", + "data": { + "version": aind_data_transfer_models_version, + "model_json": content, + "validated_model_json": validated_content, + }, + }, + ) + except ValidationError as e: + logger.warning(f"There were validation errors processing {content}") + return JSONResponse( + status_code=406, + content={ + "message": "There were validation errors", + "data": { + "version": aind_data_transfer_models_version, + "model_json": content, + "errors": e.json(), + }, + }, + ) + except Exception as e: + logger.exception("Internal Server Error.") + return JSONResponse( + status_code=500, + content={ + "message": "There was an internal server error", + "data": { + "version": aind_data_transfer_models_version, + "model_json": content, + "errors": str(e.args), + }, + }, + ) + + async def submit_jobs(request: Request): """Post BasicJobConfigs raw json to hpc server to process.""" logger.info("Received request to submit jobs") @@ -759,6 +814,7 @@ async def download_job_template(_: Request): "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"] ), Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]), + Route("/api/v1/validate_json", endpoint=validate_json, methods=["POST"]), Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]), Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]), Route( diff --git a/tests/test_server.py b/tests/test_server.py index 1eac559..0af9999 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -12,6 +12,9 @@ from aind_data_schema_models.modalities import Modality from aind_data_schema_models.platforms import Platform +from aind_data_transfer_models import ( + __version__ as aind_data_transfer_models_version, +) from aind_data_transfer_models.core import ( BasicUploadJobConfigs, ModalityConfigs, @@ -1365,10 +1368,12 @@ def test_validate_v1_malformed_csv2(self): self.assertEqual(response.status_code, 406) @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("logging.Logger.warning") @patch("requests.post") def test_submit_v1_jobs_406( self, mock_post: MagicMock, + mock_log_warning: MagicMock, ): """Tests submit jobs 406 response.""" with TestClient(app) as client: @@ -1377,6 +1382,9 @@ def test_submit_v1_jobs_406( ) self.assertEqual(406, submit_job_response.status_code) mock_post.assert_not_called() + mock_log_warning.assert_called_once_with( + "There were validation errors processing {}" + ) @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) @patch("requests.post") @@ -1698,6 +1706,95 @@ def test_submit_v1_jobs_200_basic_serialization( ) self.assertEqual(200, submit_job_response.status_code) + def test_validate_json(self): + """Tests validate_json when json is valid.""" + ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165") + + s3_bucket = "private" + subject_id = "690165" + acq_datetime = datetime(2024, 2, 19, 11, 25, 17) + platform = Platform.ECEPHYS + + ephys_config = ModalityConfigs( + modality=Modality.ECEPHYS, + source=ephys_source_dir, + ) + project_name = "Ephys Platform" + + upload_job_configs = BasicUploadJobConfigs( + project_name=project_name, + s3_bucket=s3_bucket, + platform=platform, + subject_id=subject_id, + acq_datetime=acq_datetime, + modalities=[ephys_config], + ) + submit_job_request = SubmitJobRequest(upload_jobs=[upload_job_configs]) + post_request_content = json.loads(submit_job_request.model_dump_json()) + with TestClient(app) as client: + response = client.post( + "/api/v1/validate_json", + json=post_request_content, + ) + response_json = response.json() + self.assertEqual(200, response.status_code) + self.assertEqual("Valid model", response_json["message"]) + self.assertEqual( + post_request_content, response_json["data"]["model_json"] + ) + self.assertEqual( + aind_data_transfer_models_version, response_json["data"]["version"] + ) + + @patch("logging.Logger.warning") + def test_validate_json_invalid(self, mock_log_warning: MagicMock): + """Tests validate_json when json is invalid.""" + content = {"foo": "bar"} + with TestClient(app) as client: + response = client.post( + "/api/v1/validate_json", + json=content, + ) + response_json = response.json() + self.assertEqual(406, response.status_code) + self.assertEqual( + "There were validation errors", response_json["message"] + ) + self.assertEqual(content, response_json["data"]["model_json"]) + self.assertEqual( + aind_data_transfer_models_version, response_json["data"]["version"] + ) + mock_log_warning.assert_called_once_with( + f"There were validation errors processing {content}" + ) + + @patch("logging.Logger.exception") + @patch("pydantic.BaseModel.model_validate_json") + def test_validate_json_error( + self, + mock_model_validate_json: MagicMock, + mock_log_error: MagicMock, + ): + """Tests validate_json when there is an unknown error.""" + mock_model_validate_json.side_effect = Exception("Unknown error") + with TestClient(app) as client: + response = client.post( + "/api/v1/validate_json", + json={"foo": "bar"}, + ) + response_json = response.json() + self.assertEqual(500, response.status_code) + self.assertEqual( + "There was an internal server error", response_json["message"] + ) + self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"]) + self.assertEqual("('Unknown error',)", response_json["data"]["errors"]) + self.assertEqual( + aind_data_transfer_models_version, response_json["data"]["version"] + ) + mock_model_validate_json.assert_called_once() + mock_log_error.assert_called_once_with("Internal Server Error.") + if __name__ == "__main__": unittest.main()