From a66cd4c40dc33f6a1176970b02defe98cd9848c2 Mon Sep 17 00:00:00 2001 From: jtyoung84 <104453205+jtyoung84@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:44:43 -0800 Subject: [PATCH] fix: updates parsing of modality (#72) --- .../configs/job_configs.py | 9 +++-- tests/resources/sample_alt_modality_case.csv | 4 ++ tests/test_configs.py | 37 +++++++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 tests/resources/sample_alt_modality_case.csv diff --git a/src/aind_data_transfer_service/configs/job_configs.py b/src/aind_data_transfer_service/configs/job_configs.py index 881267a..776ed5e 100644 --- a/src/aind_data_transfer_service/configs/job_configs.py +++ b/src/aind_data_transfer_service/configs/job_configs.py @@ -10,11 +10,12 @@ from aind_data_schema.models.modalities import Modality from aind_data_schema.models.platforms import Platform from pydantic import ( + ConfigDict, Field, PrivateAttr, SecretStr, ValidationInfo, - field_validator, ConfigDict, + field_validator, ) from pydantic_settings import BaseSettings @@ -77,7 +78,7 @@ def parse_modality_string( if unable to do so.""" if isinstance(input_modality, str): modality_abbreviation = cls._MODALITY_MAP.get( - input_modality.upper() + input_modality.upper().replace("-", "_") ) if modality_abbreviation is None: raise AttributeError(f"Unknown Modality: {input_modality}") @@ -105,7 +106,9 @@ class BasicUploadJobConfigs(BaseSettings): """Configuration for the basic upload job""" # Allow users to pass in extra fields - model_config = ConfigDict(extra='allow',) + model_config = ConfigDict( + extra="allow", + ) # Need some way to extract abbreviations. Maybe a public method can be # added to the Platform class diff --git a/tests/resources/sample_alt_modality_case.csv b/tests/resources/sample_alt_modality_case.csv new file mode 100644 index 0000000..b23dcad --- /dev/null +++ b/tests/resources/sample_alt_modality_case.csv @@ -0,0 +1,4 @@ +modality0, modality0.source, modality1, modality1.source, s3-bucket, subject-id, platform, acq-datetime +ecephys, dir/data_set_1, ,, some_bucket, 123454, ecephys, 2020-10-10 14:10:10 +behavior-videos, dir/data_set_2, MRI, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM +behavior-videos, dir/data_set_2, BEHAVIOR_VIDEOS, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM diff --git a/tests/test_configs.py b/tests/test_configs.py index 39652b0..deda2c0 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -18,6 +18,7 @@ RESOURCES_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / "resources" SAMPLE_FILE = RESOURCES_DIR / "sample.csv" +SAMPLE_ALT_MODALITY_CASE_FILE = RESOURCES_DIR / "sample_alt_modality_case.csv" class TestJobConfigs(unittest.TestCase): @@ -171,6 +172,42 @@ def test_parse_csv_file(self): ), ) + def test_parse_alt_csv_file(self): + """Tests that the jobs can be parsed from a csv file correctly where + the modalities are lower case.""" + + jobs = [] + + with open(SAMPLE_ALT_MODALITY_CASE_FILE, newline="") as csvfile: + reader = csv.DictReader(csvfile, skipinitialspace=True) + for row in reader: + jobs.append( + BasicUploadJobConfigs.from_csv_row( + row, aws_param_store_name="/some/param/store" + ) + ) + + modality_outputs = [] + for job in jobs: + job_s3_prefix = job.s3_prefix + for modality in job.modalities: + modality_outputs.append( + ( + job_s3_prefix, + modality.default_output_folder_name, + modality.number_id, + ) + ) + expected_modality_outputs = [ + ("ecephys_123454_2020-10-10_14-10-10", "ecephys", None), + ("behavior_123456_2020-10-13_13-10-10", "behavior-videos", None), + ("behavior_123456_2020-10-13_13-10-10", "MRI", None), + ("behavior_123456_2020-10-13_13-10-10", "behavior-videos", None), + ("behavior_123456_2020-10-13_13-10-10", "behavior-videos1", 1), + ] + self.assertEqual(self.expected_job_configs, jobs) + self.assertEqual(expected_modality_outputs, modality_outputs) + def test_malformed_platform(self): """Tests that an error is raised if an unknown platform is used"""