diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
index 1e20885bf..8f4326d55 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -52,9 +52,9 @@ def axes(self):
         logger.debug(
             "DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
             f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
-            f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
+            f"Using default {['s', 'c', 'z', 'y', 'x'][-self.dims::]}",
         )
-        return ["c", "z", "y", "x"][-self.dims : :]
+        return ["s", "c", "z", "y", "x"][-self.dims : :]
 
     @property
     def dims(self) -> int:
@@ -149,26 +149,33 @@ def create_from_array_identifier(
                 delete=overwrite,
             )
             zarr_dataset = zarr_container[array_identifier.dataset]
-            zarr_dataset.attrs["offset"] = (
-                roi.offset[::-1]
-                if array_identifier.container.name.endswith("n5")
-                else roi.offset
-            )
-            zarr_dataset.attrs["resolution"] = (
-                voxel_size[::-1]
-                if array_identifier.container.name.endswith("n5")
-                else voxel_size
-            )
-            zarr_dataset.attrs["axes"] = (
-                axes[::-1] if array_identifier.container.name.endswith("n5") else axes
-            )
-            # to make display right in neuroglancer: TODO
-            zarr_dataset.attrs["dimension_units"] = [
-                f"{size} nm" for size in zarr_dataset.attrs["resolution"]
-            ]
-            zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = (
-                axes[::-1] if array_identifier.container.name.endswith("n5") else axes
-            )
+            if array_identifier.container.name.endswith("n5"):
+                zarr_dataset.attrs["offset"] = roi.offset[::-1]
+                zarr_dataset.attrs["resolution"] = voxel_size[::-1]
+                zarr_dataset.attrs["axes"] = axes[::-1]
+                # to make display right in neuroglancer: TODO ADD CHANNELS
+                zarr_dataset.attrs["dimension_units"] = [
+                    f"{size} nm" for size in voxel_size[::-1]
+                ]
+                zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = axes[::-1]
+            else:
+                zarr_dataset.attrs["offset"] = roi.offset
+                zarr_dataset.attrs["resolution"] = voxel_size
+                zarr_dataset.attrs["axes"] = axes
+                # to make display right in neuroglancer: TODO ADD CHANNELS
+                zarr_dataset.attrs["dimension_units"] = [
+                    f"{size} nm" for size in voxel_size
+                ]
+                zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = axes
+            if num_channels is not None:
+                if axes.index("c") == 0:
+                    zarr_dataset.attrs["dimension_units"] = [
+                        num_channels
+                    ] + zarr_dataset.attrs["dimension_units"]
+                else:
+                    zarr_dataset.attrs["dimension_units"] = zarr_dataset.attrs[
+                        "dimension_units"
+                    ] + [num_channels]
         except zarr.errors.ContainsArrayError:
             zarr_dataset = zarr_container[array_identifier.dataset]
             assert (
diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
index 42863b56d..24146c5c7 100644
--- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
@@ -39,6 +39,7 @@ def process(
             None,
             self.prediction_array.voxel_size,
             np.uint8,
+            block_size,
         )
 
         read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
index 5d3b45220..5c32801c2 100644
--- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
@@ -51,6 +51,7 @@ def process(
             self.prediction_array.num_channels,
             self.prediction_array.voxel_size,
             np.uint8,
+            block_size,
         )
 
         read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py
index fa9d10a47..0a0b1e0a3 100644
--- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py
@@ -46,6 +46,7 @@ def process(
             None,
             self.prediction_array.voxel_size,
             np.uint64,
+            block_size,
         )
 
         read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py
index 57166beb3..10ea1603d 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer.py
@@ -260,6 +260,7 @@ def iterate(self, num_iterations, model, optimizer, device):
                             v.num_channels,
                             v.voxel_size,
                             v.dtype if not v.dtype == bool else np.float32,
+                            model.output_shape,
                         )
                         dataset = snapshot_zarr[k]
                     else:
diff --git a/dacapo/predict.py b/dacapo/predict.py
index 4a5fa9ebc..90b071470 100644
--- a/dacapo/predict.py
+++ b/dacapo/predict.py
@@ -120,6 +120,7 @@ def predict(
         output_voxel_size,
         output_dtype,
         overwrite=overwrite,
+        write_size=output_size,
     )
 
     # run blockwise prediction
diff --git a/dacapo/train.py b/dacapo/train.py
index 6a0d00d54..4c84f63d2 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -170,10 +170,6 @@ def train_run(run: Run):
                 stats_store.store_training_stats(run.name, run.training_stats)
             continue
 
-        run.model.eval()
-        # free up optimizer memory to allow larger validation blocks
-        run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True)
-
         stats_store.store_training_stats(run.name, run.training_stats)
         weights_store.store_weights(run, iteration_stats.iteration + 1)
         try:
@@ -200,8 +196,4 @@ def train_run(run: Run):
                 exc_info=e,
             )
 
-        # make sure to move optimizer back to the correct device
-        run.move_optimizer(compute_context.device)
-        run.model.train()
-
     print("Trained until %d, finished.", trained_until)
diff --git a/dacapo/validate.py b/dacapo/validate.py
index e1d6065a5..aa2b2a676 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -228,6 +228,7 @@ def validate_run(
                 post_processed_array.num_channels,
                 post_processed_array.voxel_size,
                 post_processed_array.dtype,
+                output_size,
             )
             best_array[best_array.roi] = post_processed_array[
                 post_processed_array.roi
diff --git a/pyproject.toml b/pyproject.toml
index 2cb13a424..49b7fba29 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,7 +86,8 @@ examples = [
 ]
 pretrained = [
     "pyqt5",
-    "empanada-napari",
+    "empanada-napari",
+    "cellmap-models",
 ]
 all = ["dacapo-ml[test,dev,docs,examples,pretrained]"]
 