Skip to content

Commit

Permalink
feat: use Path or str type (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtyoung84 authored Oct 25, 2024
1 parent d7094e2 commit 472743c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "aind-data-transformation"
description = "Generated from aind-library-template"
description = "Generic Etl Job template that can be imported"
license = {text = "MIT"}
requires-python = ">=3.8"
authors = [
Expand Down
17 changes: 14 additions & 3 deletions src/aind_data_transformation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pydantic import BaseModel, ConfigDict, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

PathLike = TypeVar("PathLike", str, Path)


def get_parser() -> argparse.ArgumentParser:
"""
Expand Down Expand Up @@ -49,8 +51,8 @@ class BasicJobSettings(BaseSettings):
"""Model to define Transformation Job Configs"""

model_config = SettingsConfigDict(env_prefix="TRANSFORMATION_JOB_")
input_source: Path
output_directory: Path
input_source: PathLike
output_directory: PathLike

@classmethod
def from_config_file(cls, config_file_location: Path):
Expand Down Expand Up @@ -92,7 +94,16 @@ def __init__(self, job_settings: _T):
job_settings : _T
Generic type that is bound by the BaseSettings class.
"""
self.job_settings = job_settings
self.job_settings = job_settings.model_copy(deep=True)
# Parse str into Paths
if isinstance(self.job_settings.input_source, str):
self.job_settings.input_source = Path(
self.job_settings.input_source
)
if isinstance(self.job_settings.output_directory, str):
self.job_settings.output_directory = Path(
self.job_settings.output_directory
)

@abstractmethod
def run_job(self) -> JobResponse:
Expand Down
14 changes: 12 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,22 @@ def setUpClass(cls) -> None:
"""Set up tests with basic job settings and etl job"""
basic_settings = ExampleJobSettings(
param=2,
input_source=Path("some_input_dir"),
output_directory=Path("some_output_dir"),
input_source="some_input_dir",
output_directory="some_output_dir",
)
cls.basic_settings = basic_settings
cls.basic_job = ExampleJob(job_settings=basic_settings)

def test_settings_with_paths(self):
"""Tests JobSettings can be set with Path types if desired."""
basic_settings = ExampleJobSettings(
param=2,
input_source=Path("some_input_dir"),
output_directory=Path("some_out_dir"),
)
self.assertEqual(Path("some_input_dir"), basic_settings.input_source)
self.assertEqual(Path("some_out_dir"), basic_settings.output_directory)

def test_load_cli_args_json_str(self):
"""Tests loading json string defined in command line args"""
job_settings_json = self.basic_settings.model_dump_json()
Expand Down

0 comments on commit 472743c

Please sign in to comment.