From 4e81a81ee7f76ff83ffa3eb9587a233bd77da712 Mon Sep 17 00:00:00 2001
From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com>
Date: Thu, 1 Feb 2024 14:31:59 -0800
Subject: [PATCH] feat: adds ability to write to open-data account (#68)

---
 src/aind_data_transfer_service/server.py | 22 +++++--
 tests/test_server.py                     | 79 +++++++++++++++++++++++-
 2 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py
index 32a896d..93b57a5 100644
--- a/src/aind_data_transfer_service/server.py
+++ b/src/aind_data_transfer_service/server.py
@@ -43,6 +43,10 @@
 # HPC_STAGING_DIRECTORY
 # HPC_AWS_PARAM_STORE_NAME
 # BASIC_JOB_SCRIPT
+# OPEN_DATA_AWS_SECRET_ACCESS_KEY
+# OPEN_DATA_AWS_ACCESS_KEY_ID
+
+OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data")
 
 
 async def validate_csv(request: Request):
@@ -176,6 +180,18 @@ async def submit_hpc_jobs(request: Request):  # noqa: C901
                 job["upload_job_settings"]
             ).s3_prefix
             upload_job_configs = json.loads(job["upload_job_settings"])
+            # The aws creds to use are different for aind-open-data and
+            # everything else
+            if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
+                aws_secret_access_key = SecretStr(
+                    os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
+                )
+                aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
+            else:
+                aws_secret_access_key = SecretStr(
+                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
+                )
+                aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
             hpc_settings = json.loads(job["hpc_settings"])
             if basic_job_name is not None:
                 hpc_settings["name"] = basic_job_name
@@ -183,10 +199,8 @@
                 logging_directory=PurePosixPath(
                     os.getenv("HPC_LOGGING_DIRECTORY")
                 ),
-                aws_secret_access_key=SecretStr(
-                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
-                ),
-                aws_access_key_id=os.getenv("HPC_AWS_ACCESS_KEY_ID"),
+                aws_secret_access_key=aws_secret_access_key,
+                aws_access_key_id=aws_access_key_id,
                 aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
                 aws_session_token=(
                     (
diff --git a/tests/test_server.py b/tests/test_server.py
index 45595f2..54c614c 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -4,10 +4,11 @@
 import os
 import unittest
 from copy import deepcopy
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 from unittest.mock import MagicMock, patch
 
 from fastapi.testclient import TestClient
+from pydantic import SecretStr
 from requests import Response
 
 from aind_data_transfer_service.server import app
@@ -41,6 +42,8 @@ class TestServer(unittest.TestCase):
         "APP_SECRET_KEY": "test_app_key",
         "HPC_STAGING_DIRECTORY": "/stage/dir",
         "HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
+        "OPEN_DATA_AWS_SECRET_ACCESS_KEY": "open_data_aws_key",
+        "OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id",
     }
 
     with open(SAMPLE_CSV, "r") as file:
@@ -229,7 +232,7 @@ def test_validate_malformed_csv(self):
         response = client.post(url="/api/validate_csv", files=files)
         self.assertEqual(response.status_code, 406)
         self.assertEqual(
-            [("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
+            ["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
             response.json()["data"]["errors"],
         )
 
@@ -244,7 +247,7 @@ def test_validate_malformed_xlsx(self):
         response = client.post(url="/api/validate_csv", files=files)
         self.assertEqual(response.status_code, 406)
         self.assertEqual(
-            [("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
+            ["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
             response.json()["data"]["errors"],
         )
 
@@ -305,6 +308,76 @@ def test_submit_hpc_jobs(
         self.assertEqual(200, submit_job_response.status_code)
         self.assertEqual(2, mock_sleep.call_count)
 
+    @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
+    @patch("aind_data_transfer_service.server.sleep", return_value=None)
+    @patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
+    @patch(
+        "aind_data_transfer_service.hpc.models.HpcJobSubmitSettings"
+        ".from_upload_job_configs"
+    )
+    def test_submit_hpc_jobs_open_data(
+        self,
+        mock_from_upload_configs: MagicMock,
+        mock_submit_job: MagicMock,
+        mock_sleep: MagicMock,
+    ):
+        """Tests submit hpc jobs success for open data bucket."""
+        mock_response = Response()
+        mock_response.status_code = 200
+        mock_response._content = b'{"message": "success"}'
+        mock_submit_job.return_value = mock_response
+        # When a user specifies aind-open-data in the upload_job_settings,
+        # use the credentials for that account.
+        post_request_content = {
+            "jobs": [
+                {
+                    "hpc_settings": '{"qos":"production", "name": "job1"}',
+                    "upload_job_settings": (
+                        '{"s3_bucket": "aind-open-data", '
+                        '"platform": {"name": "Behavior platform", '
+                        '"abbreviation": "behavior"}, '
+                        '"modalities": ['
+                        '{"modality": {"name": "Behavior videos", '
+                        '"abbreviation": "behavior-videos"}, '
+                        '"source": "dir/data_set_2", '
+                        '"compress_raw_data": true, '
+                        '"skip_staging": false}], '
+                        '"subject_id": "123456", '
+                        '"acq_datetime": "2020-10-13T13:10:10", '
+                        '"process_name": "Other", '
+                        '"log_level": "WARNING", '
+                        '"metadata_dir_force": false, '
+                        '"dry_run": false, '
+                        '"force_cloud_sync": false}'
+                    ),
+                    "script": "",
+                }
+            ]
+        }
+        with TestClient(app) as client:
+            submit_job_response = client.post(
+                url="/api/submit_hpc_jobs", json=post_request_content
+            )
+        expected_response = {
+            "message": "Submitted Jobs.",
+            "data": {
+                "responses": [{"message": "success"}],
+                "errors": [],
+            },
+        }
+        self.assertEqual(expected_response, submit_job_response.json())
+        self.assertEqual(200, submit_job_response.status_code)
+        self.assertEqual(1, mock_sleep.call_count)
+        mock_from_upload_configs.assert_called_with(
+            logging_directory=PurePosixPath("hpc_logs"),
+            aws_secret_access_key=SecretStr("open_data_aws_key"),
+            aws_access_key_id="open_data_aws_key_id",
+            aws_default_region="aws_region",
+            aws_session_token=None,
+            qos="production",
+            name="behavior_123456_2020-10-13_13-10-10",
+        )
+
     @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
     @patch("aind_data_transfer_service.server.sleep", return_value=None)
     @patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
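
Usage sketch (not part of the patch): a minimal, hypothetical client-side example of the behavior this change adds. Posting a job whose upload_job_settings contains "s3_bucket": "aind-open-data" makes the server sign the upload with the OPEN_DATA_AWS_* credentials; any other bucket falls back to the HPC_AWS_* credentials. The service URL and the requests dependency below are illustrative assumptions; the payload mirrors the one used in test_submit_hpc_jobs_open_data.

    import json

    import requests  # assumed available on the client; not used by the service itself

    # Hypothetical deployment address of aind-data-transfer-service.
    SERVICE_URL = "http://localhost:5000/api/submit_hpc_jobs"

    # Upload settings for a job destined for the open-data account.
    upload_job_settings = {
        "s3_bucket": "aind-open-data",  # triggers the OPEN_DATA_AWS_* credentials
        "platform": {"name": "Behavior platform", "abbreviation": "behavior"},
        "modalities": [
            {
                "modality": {
                    "name": "Behavior videos",
                    "abbreviation": "behavior-videos",
                },
                "source": "dir/data_set_2",
            }
        ],
        "subject_id": "123456",
        "acq_datetime": "2020-10-13T13:10:10",
    }

    # The endpoint expects hpc_settings and upload_job_settings as JSON strings.
    payload = {
        "jobs": [
            {
                "hpc_settings": json.dumps({"qos": "production", "name": "job1"}),
                "upload_job_settings": json.dumps(upload_job_settings),
                "script": "",
            }
        ]
    }

    response = requests.post(SERVICE_URL, json=payload)
    print(response.status_code, response.json())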