Skip to content

Commit

Permalink
feat: endpoint to validate json (#187)
Browse files Browse the repository at this point in the history
* feat: POST /api/v1/validate_json

* docs: add validate example in User Guide
  • Loading branch information
helen-m-lin authored Dec 18, 2024
1 parent c0f7415 commit 9f4928c
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/UserGuide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ endpoint:
)
post_request_content = json.loads(submit_request.model_dump_json(exclude_none=True))
# Optionally validate the submit_request before submitting
validate_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/validate_json", json=post_request_content)
print(validate_job_response.status_code)
print(validate_job_response.json())
# Uncomment the following to submit the request
# submit_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/submit_jobs", json=post_request_content)
# print(submit_job_response.status_code)
Expand Down
56 changes: 56 additions & 0 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from typing import Optional

import requests
from aind_data_transfer_models import (
__version__ as aind_data_transfer_models_version,
)
from aind_data_transfer_models.core import SubmitJobRequest
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
Expand Down Expand Up @@ -173,6 +176,58 @@ async def validate_csv_legacy(request: Request):
)


async def validate_json(request: Request):
"""Validate raw json against aind-data-transfer-models. Returns validated
json or errors if request is invalid."""
logger.info("Received request to validate json")
content = await request.json()
try:
validated_model = SubmitJobRequest.model_validate_json(
json.dumps(content)
)
validated_content = json.loads(
validated_model.model_dump_json(warnings=False, exclude_none=True)
)
logger.info("Valid model detected")
return JSONResponse(
status_code=200,
content={
"message": "Valid model",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"validated_model_json": validated_content,
},
},
)
except ValidationError as e:
logger.warning(f"There were validation errors processing {content}")
return JSONResponse(
status_code=406,
content={
"message": "There were validation errors",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": e.json(),
},
},
)
except Exception as e:
logger.exception("Internal Server Error.")
return JSONResponse(
status_code=500,
content={
"message": "There was an internal server error",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": str(e.args),
},
},
)


async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
logger.info("Received request to submit jobs")
Expand Down Expand Up @@ -759,6 +814,7 @@ async def download_job_template(_: Request):
"/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
),
Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]),
Route("/api/v1/validate_json", endpoint=validate_json, methods=["POST"]),
Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]),
Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]),
Route(
Expand Down
97 changes: 97 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_transfer_models import (
__version__ as aind_data_transfer_models_version,
)
from aind_data_transfer_models.core import (
BasicUploadJobConfigs,
ModalityConfigs,
Expand Down Expand Up @@ -1365,10 +1368,12 @@ def test_validate_v1_malformed_csv2(self):
self.assertEqual(response.status_code, 406)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.Logger.warning")
@patch("requests.post")
def test_submit_v1_jobs_406(
self,
mock_post: MagicMock,
mock_log_warning: MagicMock,
):
"""Tests submit jobs 406 response."""
with TestClient(app) as client:
Expand All @@ -1377,6 +1382,9 @@ def test_submit_v1_jobs_406(
)
self.assertEqual(406, submit_job_response.status_code)
mock_post.assert_not_called()
mock_log_warning.assert_called_once_with(
"There were validation errors processing {}"
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
Expand Down Expand Up @@ -1698,6 +1706,95 @@ def test_submit_v1_jobs_200_basic_serialization(
)
self.assertEqual(200, submit_job_response.status_code)

def test_validate_json(self):
"""Tests validate_json when json is valid."""
ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165")

s3_bucket = "private"
subject_id = "690165"
acq_datetime = datetime(2024, 2, 19, 11, 25, 17)
platform = Platform.ECEPHYS

ephys_config = ModalityConfigs(
modality=Modality.ECEPHYS,
source=ephys_source_dir,
)
project_name = "Ephys Platform"

upload_job_configs = BasicUploadJobConfigs(
project_name=project_name,
s3_bucket=s3_bucket,
platform=platform,
subject_id=subject_id,
acq_datetime=acq_datetime,
modalities=[ephys_config],
)
submit_job_request = SubmitJobRequest(upload_jobs=[upload_job_configs])
post_request_content = json.loads(submit_job_request.model_dump_json())
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json=post_request_content,
)
response_json = response.json()
self.assertEqual(200, response.status_code)
self.assertEqual("Valid model", response_json["message"])
self.assertEqual(
post_request_content, response_json["data"]["model_json"]
)
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)

@patch("logging.Logger.warning")
def test_validate_json_invalid(self, mock_log_warning: MagicMock):
"""Tests validate_json when json is invalid."""
content = {"foo": "bar"}
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json=content,
)
response_json = response.json()
self.assertEqual(406, response.status_code)
self.assertEqual(
"There were validation errors", response_json["message"]
)
self.assertEqual(content, response_json["data"]["model_json"])
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)
mock_log_warning.assert_called_once_with(
f"There were validation errors processing {content}"
)

@patch("logging.Logger.exception")
@patch("pydantic.BaseModel.model_validate_json")
def test_validate_json_error(
self,
mock_model_validate_json: MagicMock,
mock_log_error: MagicMock,
):
"""Tests validate_json when there is an unknown error."""
mock_model_validate_json.side_effect = Exception("Unknown error")
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json={"foo": "bar"},
)
response_json = response.json()
self.assertEqual(500, response.status_code)
self.assertEqual(
"There was an internal server error", response_json["message"]
)
self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"])
self.assertEqual("('Unknown error',)", response_json["data"]["errors"])
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)
mock_model_validate_json.assert_called_once()
mock_log_error.assert_called_once_with("Internal Server Error.")


if __name__ == "__main__":
unittest.main()

0 comments on commit 9f4928c

Please sign in to comment.