feat: adds ability to write to open-data account (#68)
jtyoung84 authored Feb 1, 2024
1 parent 9d6f031 commit 4e81a81
Showing 2 changed files with 94 additions and 7 deletions.
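For context, here is a hypothetical client-side sketch (not part of this commit) of how a caller could exercise the new open-data path. The payload shape mirrors the test added below; the service URL and the use of the requests library are assumptions.

import json

import requests

SERVICE_URL = "http://localhost:5000"  # placeholder; depends on deployment

upload_job_settings = {
    "s3_bucket": "aind-open-data",  # triggers the open-data credential path
    "platform": {"name": "Behavior platform", "abbreviation": "behavior"},
    "modalities": [
        {
            "modality": {
                "name": "Behavior videos",
                "abbreviation": "behavior-videos",
            },
            "source": "dir/data_set_2",
        }
    ],
    "subject_id": "123456",
    "acq_datetime": "2020-10-13T13:10:10",
}

post_request_content = {
    "jobs": [
        {
            "hpc_settings": json.dumps({"qos": "production", "name": "job1"}),
            "upload_job_settings": json.dumps(upload_job_settings),
            "script": "",
        }
    ]
}

response = requests.post(
    url=f"{SERVICE_URL}/api/submit_hpc_jobs", json=post_request_content
)
print(response.status_code, response.json())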
22 changes: 18 additions & 4 deletions src/aind_data_transfer_service/server.py
@@ -43,6 +43,10 @@
# HPC_STAGING_DIRECTORY
# HPC_AWS_PARAM_STORE_NAME
# BASIC_JOB_SCRIPT
# OPEN_DATA_AWS_SECRET_ACCESS_KEY
# OPEN_DATA_AWS_ACCESS_KEY_ID

OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data")


async def validate_csv(request: Request):
@@ -176,17 +180,27 @@ async def submit_hpc_jobs(request: Request): # noqa: C901
job["upload_job_settings"]
).s3_prefix
upload_job_configs = json.loads(job["upload_job_settings"])
# The aws creds to use are different for aind-open-data and
# everything else
if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
aws_secret_access_key = SecretStr(
os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
)
aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
else:
aws_secret_access_key = SecretStr(
os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
)
aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
hpc_settings = json.loads(job["hpc_settings"])
if basic_job_name is not None:
hpc_settings["name"] = basic_job_name
hpc_job = HpcJobSubmitSettings.from_upload_job_configs(
logging_directory=PurePosixPath(
os.getenv("HPC_LOGGING_DIRECTORY")
),
aws_secret_access_key=SecretStr(
os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
),
aws_access_key_id=os.getenv("HPC_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
aws_session_token=(
(
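The credential branch added above can be read as a small selection helper. The sketch below is illustrative only and is not code from this commit; the helper name select_aws_credentials is invented, and it simply restates the env-var lookups shown in the hunk.

import os

from pydantic import SecretStr

# Bucket name that routes uploads to the open-data account.
OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data")


def select_aws_credentials(s3_bucket: str):
    """Return (aws_secret_access_key, aws_access_key_id) for the bucket.

    The aws creds to use are different for aind-open-data and everything
    else. The relevant env vars are assumed to be set in the service
    environment.
    """
    if s3_bucket == OPEN_DATA_BUCKET_NAME:
        return (
            SecretStr(os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")),
            os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID"),
        )
    return (
        SecretStr(os.getenv("HPC_AWS_SECRET_ACCESS_KEY")),
        os.getenv("HPC_AWS_ACCESS_KEY_ID"),
    )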
79 changes: 76 additions & 3 deletions tests/test_server.py
@@ -4,10 +4,11 @@
import os
import unittest
from copy import deepcopy
from pathlib import Path
from pathlib import Path, PurePosixPath
from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient
from pydantic import SecretStr
from requests import Response

from aind_data_transfer_service.server import app
@@ -41,6 +42,8 @@ class TestServer(unittest.TestCase):
"APP_SECRET_KEY": "test_app_key",
"HPC_STAGING_DIRECTORY": "/stage/dir",
"HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
"OPEN_DATA_AWS_SECRET_ACCESS_KEY": "open_data_aws_key",
"OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id",
}

with open(SAMPLE_CSV, "r") as file:
@@ -229,7 +232,7 @@ def test_validate_malformed_csv(self):
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
[("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
response.json()["data"]["errors"],
)

@@ -244,7 +247,7 @@ def test_validate_malformed_xlsx(self):
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
[("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
response.json()["data"]["errors"],
)

@@ -305,6 +308,76 @@ def test_submit_hpc_jobs(
self.assertEqual(200, submit_job_response.status_code)
self.assertEqual(2, mock_sleep.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
@patch(
"aind_data_transfer_service.hpc.models.HpcJobSubmitSettings"
".from_upload_job_configs"
)
def test_submit_hpc_jobs_open_data(
self,
mock_from_upload_configs: MagicMock,
mock_submit_job: MagicMock,
mock_sleep: MagicMock,
):
"""Tests submit hpc jobs success."""
mock_response = Response()
mock_response.status_code = 200
mock_response._content = b'{"message": "success"}'
mock_submit_job.return_value = mock_response
# When a user specifies aind-open-data in the upload_job_settings,
# use the credentials for that account.
post_request_content = {
"jobs": [
{
"hpc_settings": '{"qos":"production", "name": "job1"}',
"upload_job_settings": (
'{"s3_bucket": "aind-open-data", '
'"platform": {"name": "Behavior platform", '
'"abbreviation": "behavior"}, '
'"modalities": ['
'{"modality": {"name": "Behavior videos", '
'"abbreviation": "behavior-videos"}, '
'"source": "dir/data_set_2", '
'"compress_raw_data": true, '
'"skip_staging": false}], '
'"subject_id": "123456", '
'"acq_datetime": "2020-10-13T13:10:10", '
'"process_name": "Other", '
'"log_level": "WARNING", '
'"metadata_dir_force": false, '
'"dry_run": false, '
'"force_cloud_sync": false}'
),
"script": "",
}
]
}
with TestClient(app) as client:
submit_job_response = client.post(
url="/api/submit_hpc_jobs", json=post_request_content
)
expected_response = {
"message": "Submitted Jobs.",
"data": {
"responses": [{"message": "success"}],
"errors": [],
},
}
self.assertEqual(expected_response, submit_job_response.json())
self.assertEqual(200, submit_job_response.status_code)
self.assertEqual(1, mock_sleep.call_count)
mock_from_upload_configs.assert_called_with(
logging_directory=PurePosixPath("hpc_logs"),
aws_secret_access_key=SecretStr("open_data_aws_key"),
aws_access_key_id="open_data_aws_key_id",
aws_default_region="aws_region",
aws_session_token=None,
qos="production",
name="behavior_123456_2020-10-13_13-10-10",
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
