diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 720b81d9b..0da8db384 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -8,6 +8,7 @@ on: push: branches: - main + - dev/main workflow_dispatch: jobs: @@ -31,8 +32,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e . - pip install -r requirements-dev.txt + pip install -e ".[test]" - name: Test run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing diff --git a/CONTRIBUTOR.md b/CONTRIBUTOR.md index 5e37cb8d5..5604ad317 100644 --- a/CONTRIBUTOR.md +++ b/CONTRIBUTOR.md @@ -23,5 +23,6 @@ This will also be run automatically when a PR is made to master and a codecov re ## Branching and PRs - Users that have been added to the CellMap organization and the DaCapo project should be able to develop directly into the CellMap fork of DaCapo. Other users will need to create a fork. -- For a completely new feature, make a branch off of the `main` branch of CellMap's fork of DaCapo with a name describing the feature. If you are collaborating on a feature that already has a branch, you can branch off that feature branch. -- Currently, you should make your PRs into the main branch of CellMap's fork, or the feature branch you branched off of. PRs currently require one maintainer's approval before merging. Once the PR is merged, the feature branch should be deleted. +- For a completely new feature, make a branch off of the `dev/main` branch of CellMap's fork of DaCapo with a name describing the feature. If you are collaborating on a feature that already has a branch, you can branch off that feature branch. +- Currently, you should make your PRs into the `dev/main` branch of CellMap's fork, or the feature branch you branched off of. PRs currently require one maintainer's approval before merging. Once the PR is merged, the feature branch should be deleted. 
+- `dev/main` will be regularly merged to `main` when new features are fully implemented and all tests are passing. diff --git a/dacapo/apply.py b/dacapo/apply.py index a701d9272..7ee0473fb 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -14,7 +14,7 @@ from dacapo.predict import predict from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray -from dacapo.store import ( +from dacapo.store.create_store import ( create_config_store, create_weights_store, ) @@ -174,8 +174,8 @@ def apply_run( run: Run, parameters: PostProcessorParameters, input_array: Array, - prediction_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", roi: Optional[Roi] = None, num_cpu_workers: int = 30, output_dtype: Optional[np.dtype] = np.uint8, # type: ignore diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index 22bb825d3..a4e23578a 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -72,8 +72,8 @@ def start_worker( def spawn_worker( - input_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + input_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. 
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 237062f7e..d4df66ad9 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -170,8 +170,8 @@ def start_worker( def spawn_worker( run_name: str, iteration: int, - raw_array_identifier: LocalArrayIdentifier, - prediction_array_identifier: LocalArrayIdentifier, + raw_array_identifier: "LocalArrayIdentifier", + prediction_array_identifier: "LocalArrayIdentifier", compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index 2f7af5236..929eebbf0 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -73,8 +73,8 @@ def start_worker( def spawn_worker( - input_array_identifier: LocalArrayIdentifier, - output_array_identifier: LocalArrayIdentifier, + input_array_identifier: "LocalArrayIdentifier", + output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, compute_context: ComputeContext = LocalTorch(), ): diff --git a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py index 1782f028e..04b163513 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py @@ -35,7 +35,7 @@ def voxel_size(self) -> Coordinate: @property def roi(self) -> Roi: - return self.crop_roi + return self.crop_roi.intersect(self._source_array.roi) @property def writable(self) -> bool: diff --git a/dacapo/experiments/model.py b/dacapo/experiments/model.py index 8ca2b2b9e..75777cd81 100644 --- a/dacapo/experiments/model.py +++ b/dacapo/experiments/model.py @@ -40,6 +40,13 @@ def __init__( ) self.eval_activation = eval_activation + # UPDATE WEIGHT INITIALIZATION TO USE KAIMING + # TODO: put this somewhere better, there might be + # conv layers that 
aren't followed by relus? + for _name, layer in self.named_modules(): + if isinstance(layer, torch.nn.modules.conv._ConvNd): + torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu") + def forward(self, x): result = self.chain(x) if not self.training and self.eval_activation is not None: diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 99bba201f..cb839630d 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -36,7 +36,7 @@ def set_prediction(self, prediction_array_identifier): def process( self, parameters: WatershedPostProcessorParameters, - output_array_identifier: LocalArrayIdentifier, + output_array_identifier: "LocalArrayIdentifier", ): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, diff --git a/dacapo/predict.py b/dacapo/predict.py index 131820572..0e8dfaf45 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -5,11 +5,11 @@ from dacapo.experiments import Run from dacapo.gp import DaCapoArraySource from dacapo.experiments import Model -from dacapo.store import create_config_store -from dacapo.store import create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray, Array +from dacapo.cli import cli from funlib.geometry import Coordinate, Roi import gunpowder as gp diff --git a/dacapo/store/__init__.py b/dacapo/store/__init__.py index 420359ae4..e69de29bb 100644 --- a/dacapo/store/__init__.py +++ b/dacapo/store/__init__.py @@ -1,7 +0,0 @@ -from .converter import converter -from .create_store import ( - create_array_store, - create_config_store, - create_stats_store, - 
create_weights_store, -) diff --git a/dacapo/store/array_store.py b/dacapo/store/array_store.py index 065196066..7c44ab7ab 100644 --- a/dacapo/store/array_store.py +++ b/dacapo/store/array_store.py @@ -62,7 +62,7 @@ def validation_input_arrays( pass @abstractmethod - def remove(self, array_identifier: LocalArrayIdentifier) -> None: + def remove(self, array_identifier: "LocalArrayIdentifier") -> None: """Remove an array by its identifier.""" pass diff --git a/dacapo/store/file_config_store.py b/dacapo/store/file_config_store.py index 5fbe1ca5c..09f8215cd 100644 --- a/dacapo/store/file_config_store.py +++ b/dacapo/store/file_config_store.py @@ -98,10 +98,12 @@ def __save_insert(self, collection, data, ignore=None): file_store = collection / f"{name}.toml" if not file_store.exists(): - toml.dump(dict(data), file_store.open("w")) + with file_store.open("w") as f: + toml.dump(dict(data), f) else: - existing = toml.load(file_store.open("r")) + with file_store.open("r") as f: + existing = toml.load(f) if not self.__same_doc(existing, data, ignore): raise DuplicateNameError( @@ -113,7 +115,8 @@ def __save_insert(self, collection, data, ignore=None): def __load(self, collection, name): file_store = collection / f"{name}.toml" if file_store.exists(): - return toml.load(file_store.open("r")) + with file_store.open("r") as f: + return toml.load(f) else: raise ValueError(f"No config with name: {name} in collection: {collection}") diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index c1581fc7b..73994d980 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -85,7 +85,7 @@ def validation_container(self, run_name: str) -> LocalContainerIdentifier: Path(self.__get_run_dir(run_name), "validation.zarr") ) - def remove(self, array_identifier: LocalArrayIdentifier) -> None: + def remove(self, array_identifier: "LocalArrayIdentifier") -> None: container = array_identifier.container dataset = array_identifier.dataset 
diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index d2b0fa02d..28adacdac 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -97,7 +97,12 @@ def store_best(self, run: str, iteration: int, dataset: str, criterion: str): if best_weights.exists(): best_weights.unlink() - best_weights.symlink_to(iteration_weights) + try: + best_weights.symlink_to(iteration_weights) + except FileExistsError: + best_weights.unlink() + best_weights.symlink_to(iteration_weights) + with best_weights_json.open("w") as f: f.write(json.dumps({"iteration": iteration})) diff --git a/dacapo/train.py b/dacapo/train.py index 7beb096b4..04642e4d3 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,8 +1,12 @@ -from dacapo.store.create_store import create_array_store -from .experiments import Run -from .compute_context import LocalTorch, ComputeContext -from .store import create_config_store, create_stats_store, create_weights_store -from .validate import validate_run +from dacapo.store.create_store import ( + create_array_store, + create_config_store, + create_stats_store, + create_weights_store, +) +from dacapo.experiments import Run +from dacapo.compute_context import LocalTorch, ComputeContext +from dacapo.validate import validate_run import torch from tqdm import tqdm diff --git a/dacapo/validate.py b/dacapo/validate.py index a1cf9da7d..348549f32 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -2,7 +2,7 @@ from .compute_context import LocalTorch, ComputeContext from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray -from .store import ( +from .store.create_store import ( create_array_store, create_config_store, create_stats_store, diff --git a/pyproject.toml b/pyproject.toml index 21512e15b..18af7fdfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,9 @@ dependencies = [ "xarray", "cattrs", "numpy-indexed", - "click",] + 
"click", + "toml", + ] # extras # https://peps.python.org/pep-0621/#dependencies-optional-dependencies @@ -236,4 +238,4 @@ ignore = [ ".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*", -] \ No newline at end of file +] diff --git a/tests/components/test_arrays.py b/tests/components/test_arrays.py index b81a4bc7e..d62dcb973 100644 --- a/tests/components/test_arrays.py +++ b/tests/components/test_arrays.py @@ -1,6 +1,6 @@ from ..fixtures import * -from dacapo.store import create_config_store +from dacapo.store.create_store import create_config_store import pytest from pytest_lazyfixture import lazy_fixture diff --git a/tests/components/test_trainers.py b/tests/components/test_trainers.py index 8ef792ba5..172a89b75 100644 --- a/tests/components/test_trainers.py +++ b/tests/components/test_trainers.py @@ -1,6 +1,6 @@ from ..fixtures import * -from dacapo.store import create_config_store +from dacapo.store.create_store import create_config_store import pytest from pytest_lazyfixture import lazy_fixture diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 9a633a90a..53ca30b7f 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -2,7 +2,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import apply import pytest diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index abecfd9c3..846afe6c3 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -3,7 +3,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.train import train_run import pytest diff --git 
a/tests/operations/test_validate.py b/tests/operations/test_validate.py index d18a87197..54d6dc5e4 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -2,7 +2,7 @@ from dacapo.experiments import Run from dacapo.compute_context import LocalTorch -from dacapo.store import create_config_store, create_weights_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import validate import pytest