Skip to content

Commit

Permalink
test: update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
helen-m-lin committed Nov 26, 2024
1 parent 9b1da6a commit 45f226a
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
# AIND_AIRFLOW_SERVICE_PASSWORD
# AIND_AIRFLOW_SERVICE_USER

# NOTE: add cors to test metadata-entry-app
middleware = [
Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
]
Expand Down
135 changes: 135 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,141 @@ def test_submit_v1_jobs_200_basic_serialization(
)
self.assertEqual(200, submit_job_response.status_code)

@patch("pydantic.BaseModel.model_json_schema")
def test_get_json_schema_for_model(
self,
mock_model_json_schema: MagicMock,
):
"""Tests that the json schema for a model is returned."""
mock_model_json_schema.return_value = {"foo": "bar"}
models = [
"BasicUploadJobConfigsForm",
"SubmitJobRequestForm",
"BasicUploadJobConfigs",
"SubmitJobRequest",
]
for model in models:
with TestClient(app) as client:
response = client.get(f"/api/v1/models/{model}/schema")
self.assertEqual(200, response.status_code)
self.assertEqual({"foo": "bar"}, response.json())
self.assertEqual(4, mock_model_json_schema.call_count)

@patch("pydantic.BaseModel.model_json_schema")
def test_get_json_schema_for_model_error(
self,
mock_model_json_schema: MagicMock,
):
"""Tests that 404 error is returned if model is not found."""
model = "Test"
with TestClient(app) as client:
response = client.get(f"/api/v1/models/{model}/schema")
self.assertEqual(404, response.status_code)
self.assertEqual(
{"message": f"Schema not found for {model}"}, response.json()
)
mock_model_json_schema.assert_not_called()

def test_validate_json_for_model(self):
"""Tests validate_json_for_model when json is valid."""
ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165")

s3_bucket = "private"
subject_id = "690165"
acq_datetime = datetime(2024, 2, 19, 11, 25, 17)
platform = Platform.ECEPHYS

ephys_config = ModalityConfigs(
modality=Modality.ECEPHYS,
source=ephys_source_dir,
)
project_name = "Ephys Platform"

upload_job_configs = BasicUploadJobConfigs(
project_name=project_name,
s3_bucket=s3_bucket,
platform=platform,
subject_id=subject_id,
acq_datetime=acq_datetime,
modalities=[ephys_config],
)
# validate BasicUploadJobConfigs
post_request_content = json.loads(upload_job_configs.model_dump_json())
with TestClient(app) as client:
response = client.post(
"/api/v1/models/BasicUploadJobConfigs/validate",
json=post_request_content,
)
response_json = response.json()
self.assertEqual(200, response.status_code)
self.assertEqual("Valid model", response_json["message"])
self.assertEqual(
post_request_content, response_json["data"]["model_json"]
)

# validate SubmitJobRequest
submit_request = SubmitJobRequest(upload_jobs=[upload_job_configs])
post_request_content = json.loads(submit_request.model_dump_json())

with TestClient(app) as client:
response = client.post(
"/api/v1/models/SubmitJobRequest/validate",
json=post_request_content,
)
response_json = response.json()
self.assertEqual(200, response.status_code)
self.assertEqual("Valid model", response_json["message"])
self.assertEqual(
post_request_content, response_json["data"]["model_json"]
)

def test_validate_json_for_model_not_found(self):
"""Tests validate_json_for_model when model is not found."""
model = "Test"
with TestClient(app) as client:
response = client.post(
f"/api/v1/models/{model}/validate", json={"foo": "bar"}
)
response_json = response.json()
self.assertEqual(404, response.status_code)
self.assertEqual(
{"message": f"Model not found for {model}"}, response_json
)

def test_validate_json_for_model_invalid(self):
"""Tests validate_json_for_model when json is invalid."""
with TestClient(app) as client:
response = client.post(
"/api/v1/models/SubmitJobRequest/validate",
json={"foo": "bar"},
)
response_json = response.json()
self.assertEqual(406, response.status_code)
self.assertEqual(
"There were validation errors", response_json["message"]
)
self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"])

@patch("pydantic.BaseModel.model_validate_json")
def test_validate_json_for_model_error(
self, mock_model_validate_json: MagicMock
):
"""Tests validate_json_for_model when there is an unknown error."""
mock_model_validate_json.side_effect = Exception("Unknown error")
with TestClient(app) as client:
response = client.post(
"/api/v1/models/SubmitJobRequest/validate",
json={"foo": "bar"},
)
response_json = response.json()
self.assertEqual(500, response.status_code)
self.assertEqual(
"There was an internal server error", response_json["message"]
)
self.assertEqual({"foo": "bar"}, response_json["data"]["model_json"])
self.assertEqual("('Unknown error',)", response_json["data"]["errors"])
mock_model_validate_json.assert_called_once()


if __name__ == "__main__":
unittest.main()

0 comments on commit 45f226a

Please sign in to comment.