diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 720b81d9b..0da8db384 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -8,6 +8,7 @@ on: push: branches: - main + - dev/main workflow_dispatch: jobs: @@ -31,8 +32,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e . - pip install -r requirements-dev.txt + pip install -e ".[test]" - name: Test run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing diff --git a/dacapo/apply.py b/dacapo/apply.py index a701d9272..7ee0473fb 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -14,7 +14,7 @@ from dacapo.predict import predict from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray -from dacapo.store import ( +from dacapo.store.create_store import ( create_config_store, create_weights_store, ) @@ -174,8 +174,8 @@ def apply_run( run: Run, parameters: PostProcessorParameters, input_array: Array, - prediction_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", roi: Optional[Roi] = None, num_cpu_workers: int = 30, output_dtype: Optional[np.dtype] = np.uint8, # type: ignore diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index 22bb825d3..a4e23578a 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -72,8 +72,8 @@ def start_worker( def spawn_worker( - input_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + input_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 237062f7e..d4df66ad9 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -170,8 +170,8 @@ def start_worker( def spawn_worker( run_name: str, iteration: int, - raw_array_identifier: LocalArrayIdentifier, - prediction_array_identifier: LocalArrayIdentifier, + raw_array_identifier: "LocalArrayIdentifier", + prediction_array_identifier: "LocalArrayIdentifier", compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index 2f7af5236..929eebbf0 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -73,8 +73,8 @@ def start_worker( def spawn_worker( - input_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + input_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, compute_context: ComputeContext = LocalTorch(), ): diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 99bba201f..cb839630d 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -36,7 +36,7 @@ def set_prediction(self, prediction_array_identifier): def process( self, parameters: WatershedPostProcessorParameters, - output_array_identifier: LocalArrayIdentifier, + output_array_identifier: "LocalArrayIdentifier", ): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, diff --git a/dacapo/predict.py b/dacapo/predict.py index 131820572..0e8dfaf45 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -5,11 +5,11 @@ from dacapo.experiments import Run from dacapo.gp import DaCapoArraySource from dacapo.experiments import Model -from dacapo.store import create_config_store -from dacapo.store import create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray, Array +from dacapo.cli import cli from funlib.geometry import Coordinate, Roi import gunpowder as gp diff --git a/dacapo/store/__init__.py b/dacapo/store/__init__.py index 420359ae4..e69de29bb 100644 --- a/dacapo/store/__init__.py +++ b/dacapo/store/__init__.py @@ -1,7 +0,0 @@ -from .converter import converter -from .create_store import ( - create_array_store, - create_config_store, - create_stats_store, - create_weights_store, -) diff --git a/dacapo/store/array_store.py b/dacapo/store/array_store.py index 065196066..7c44ab7ab 100644 --- a/dacapo/store/array_store.py +++ b/dacapo/store/array_store.py @@ -62,7 +62,7 @@ def validation_input_arrays( pass @abstractmethod - def remove(self, array_identifier: LocalArrayIdentifier) -> None: + def remove(self, array_identifier: "LocalArrayIdentifier") -> None: """Remove an array by its identifier.""" pass diff --git a/dacapo/store/file_config_store.py b/dacapo/store/file_config_store.py index 5fbe1ca5c..09f8215cd 100644 --- a/dacapo/store/file_config_store.py +++ b/dacapo/store/file_config_store.py @@ -98,10 +98,12 @@ def __save_insert(self, collection, data, ignore=None): file_store = collection / f"{name}.toml" if not file_store.exists(): - toml.dump(dict(data), file_store.open("w")) + with file_store.open("w") as f: + toml.dump(dict(data), f) else: - existing = toml.load(file_store.open("r")) + with file_store.open("r") as f: + existing = toml.load(f) if not self.__same_doc(existing, data, ignore): raise DuplicateNameError( @@ -113,7 +115,8 @@ def __save_insert(self, collection, data, ignore=None): def __load(self, collection, name): file_store = collection / f"{name}.toml" if file_store.exists(): - return toml.load(file_store.open("r")) + with file_store.open("r") as f: + return toml.load(f) else: raise ValueError(f"No config with name: {name} in collection: {collection}") diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index c1581fc7b..73994d980 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -85,7 +85,7 @@ def validation_container(self, run_name: str) -> LocalContainerIdentifier: Path(self.__get_run_dir(run_name), "validation.zarr") ) - def remove(self, array_identifier: LocalArrayIdentifier) -> None: + def remove(self, array_identifier: "LocalArrayIdentifier") -> None: container = array_identifier.container dataset = array_identifier.dataset diff --git a/dacapo/train.py b/dacapo/train.py index 7beb096b4..04642e4d3 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,8 +1,12 @@ -from dacapo.store.create_store import create_array_store -from .experiments import Run -from .compute_context import LocalTorch, ComputeContext -from .store import create_config_store, create_stats_store, create_weights_store -from .validate import validate_run +from dacapo.store.create_store import ( + create_array_store, + create_config_store, + create_stats_store, + create_weights_store, +) +from dacapo.experiments import Run +from dacapo.compute_context import LocalTorch, ComputeContext +from dacapo.validate import validate_run import torch from tqdm import tqdm diff --git a/dacapo/validate.py b/dacapo/validate.py index a1cf9da7d..348549f32 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -2,7 +2,7 @@ from .compute_context import LocalTorch, ComputeContext from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray -from .store import ( +from .store.create_store import ( create_array_store, create_config_store, create_stats_store, diff --git a/tests/components/test_arrays.py b/tests/components/test_arrays.py index b81a4bc7e..d62dcb973 100644 --- a/tests/components/test_arrays.py +++ b/tests/components/test_arrays.py @@ -1,6 +1,6 @@ from ..fixtures import * -from dacapo.store import create_config_store +from dacapo.store.create_store import create_config_store import pytest from pytest_lazyfixture import lazy_fixture diff --git a/tests/components/test_trainers.py b/tests/components/test_trainers.py index 8ef792ba5..172a89b75 100644 --- a/tests/components/test_trainers.py +++ b/tests/components/test_trainers.py @@ -1,6 +1,6 @@ from ..fixtures import * -from dacapo.store import create_config_store +from dacapo.store.create_store import create_config_store import pytest from pytest_lazyfixture import lazy_fixture diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 9a633a90a..53ca30b7f 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -2,7 +2,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import apply import pytest diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index abecfd9c3..846afe6c3 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -3,7 +3,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.train import train_run import pytest diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index d18a87197..54d6dc5e4 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -2,7 +2,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import validate import pytest