Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: endpoint to validate json #187

Merged
merged 11 commits into from
Dec 18, 2024
4 changes: 4 additions & 0 deletions docs/source/UserGuide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ endpoint:
)

post_request_content = json.loads(submit_request.model_dump_json(exclude_none=True))
# Optionally validate the submit_request before submitting
validate_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/validate_json", json=post_request_content)
print(validate_job_response.status_code)
print(validate_job_response.json())
# Uncomment the following to submit the request
# submit_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/submit_jobs", json=post_request_content)
# print(submit_job_response.status_code)
Expand Down
56 changes: 56 additions & 0 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from typing import Optional

import requests
from aind_data_transfer_models import (
__version__ as aind_data_transfer_models_version,
)
from aind_data_transfer_models.core import SubmitJobRequest
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
Expand Down Expand Up @@ -173,6 +176,58 @@ async def validate_csv_legacy(request: Request):
)


async def validate_json(request: Request):
"""Validate raw json against aind-data-transfer-models. Returns validated
json or errors if request is invalid."""
logger.info("Received request to validate json")
content = await request.json()
try:
validated_model = SubmitJobRequest.model_validate_json(
json.dumps(content)
)
validated_content = json.loads(
validated_model.model_dump_json(warnings=False, exclude_none=True)
)
logger.info("Valid model detected")
return JSONResponse(
status_code=200,
content={
"message": "Valid model",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"validated_model_json": validated_content,
},
},
)
except ValidationError as e:
logger.warning(f"There were validation errors processing {content}")
return JSONResponse(
status_code=406,
content={
"message": "There were validation errors",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": e.json(),
},
},
)
except Exception as e:
logger.exception("Internal Server Error.")
return JSONResponse(
status_code=500,
content={
"message": "There was an internal server error",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": str(e.args),
},
},
)


async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
logger.info("Received request to submit jobs")
Expand Down Expand Up @@ -759,6 +814,7 @@ async def download_job_template(_: Request):
"/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
),
Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]),
Route("/api/v1/validate_json", endpoint=validate_json, methods=["POST"]),
Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]),
Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]),
Route(
Expand Down
97 changes: 97 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_transfer_models import (
__version__ as aind_data_transfer_models_version,
)
from aind_data_transfer_models.core import (
BasicUploadJobConfigs,
ModalityConfigs,
Expand Down Expand Up @@ -1365,10 +1368,12 @@ def test_validate_v1_malformed_csv2(self):
self.assertEqual(response.status_code, 406)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.Logger.warning")
@patch("requests.post")
def test_submit_v1_jobs_406(
self,
mock_post: MagicMock,
mock_log_warning: MagicMock,
):
"""Tests submit jobs 406 response."""
with TestClient(app) as client:
Expand All @@ -1377,6 +1382,9 @@ def test_submit_v1_jobs_406(
)
self.assertEqual(406, submit_job_response.status_code)
mock_post.assert_not_called()
mock_log_warning.assert_called_once_with(
"There were validation errors processing {}"
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
Expand Down Expand Up @@ -1698,6 +1706,95 @@ def test_submit_v1_jobs_200_basic_serialization(
)
self.assertEqual(200, submit_job_response.status_code)

def test_validate_json(self):
"""Tests validate_json when json is valid."""
ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165")

s3_bucket = "private"
subject_id = "690165"
acq_datetime = datetime(2024, 2, 19, 11, 25, 17)
platform = Platform.ECEPHYS

ephys_config = ModalityConfigs(
modality=Modality.ECEPHYS,
source=ephys_source_dir,
)
project_name = "Ephys Platform"

upload_job_configs = BasicUploadJobConfigs(
project_name=project_name,
s3_bucket=s3_bucket,
platform=platform,
subject_id=subject_id,
acq_datetime=acq_datetime,
modalities=[ephys_config],
)
submit_job_request = SubmitJobRequest(upload_jobs=[upload_job_configs])
post_request_content = json.loads(submit_job_request.model_dump_json())
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json=post_request_content,
)
response_json = response.json()
self.assertEqual(200, response.status_code)
self.assertEqual("Valid model", response_json["message"])
self.assertEqual(
post_request_content, response_json["data"]["model_json"]
)
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)

@patch("logging.Logger.warning")
def test_validate_json_invalid(self, mock_log_warning: MagicMock):
"""Tests validate_json when json is invalid."""
content = {"foo": "bar"}
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json=content,
)
response_json = response.json()
self.assertEqual(406, response.status_code)
self.assertEqual(
"There were validation errors", response_json["message"]
)
self.assertEqual(content, response_json["data"]["model_json"])
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)
mock_log_warning.assert_called_once_with(
f"There were validation errors processing {content}"
)

@patch("logging.Logger.exception")
@patch("pydantic.BaseModel.model_validate_json")
def test_validate_json_error(
self,
mock_model_validate_json: MagicMock,
mock_log_error: MagicMock,
):
"""Tests validate_json when there is an unknown error."""
mock_model_validate_json.side_effect = Exception("Unknown error")
with TestClient(app) as client:
response = client.post(
"/api/v1/validate_json",
json={"foo": "bar"},
)
response_json = response.json()
self.assertEqual(500, response.status_code)
self.assertEqual(
"There was an internal server error", response_json["message"]
)
self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"])
self.assertEqual("('Unknown error',)", response_json["data"]["errors"])
self.assertEqual(
aind_data_transfer_models_version, response_json["data"]["version"]
)
mock_model_validate_json.assert_called_once()
mock_log_error.assert_called_once_with("Internal Server Error.")


if __name__ == "__main__":
unittest.main()
Loading