From 33bbc8ae00cc3c2d6fba3d71182991e5ee89e0a5 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:08:37 +0000 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=F0=9F=9A=A7=20Incorporate=20simple?= =?UTF-8?q?=20change=20from=20rhoadesj/dev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/apply.py | 196 +++++++++++++++++- dacapo/cli.py | 55 ++++- dacapo/experiments/tasks/affinities_task.py | 6 +- .../tasks/affinities_task_config.py | 20 ++ .../tasks/predictors/affinities_predictor.py | 33 ++- dacapo/experiments/training_stats.py | 4 +- dacapo/predict.py | 29 ++- dacapo/train.py | 53 +++-- setup.py | 8 +- 9 files changed, 348 insertions(+), 56 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 64f23df3c..b33cffe46 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -1,12 +1,200 @@ import logging +from typing import Optional +from funlib.geometry import Roi, Coordinate +import numpy as np +from dacapo.experiments.datasplits.datasets.arrays.array import Array +from dacapo.experiments.datasplits.datasets.dataset import Dataset +from dacapo.experiments.run import Run + +from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( + PostProcessorParameters, +) +import dacapo.experiments.tasks.post_processors as post_processors +from dacapo.store.array_store import LocalArrayIdentifier +from dacapo.predict import predict +from dacapo.compute_context import LocalTorch, ComputeContext +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray +from dacapo.store import ( + create_config_store, + create_weights_store, +) + +from pathlib import Path logger = logging.getLogger(__name__) -def apply(run_name: str, iteration: int, dataset_name: str): +def apply( + run_name: str, + input_container: Path or str, + input_dataset: str, + output_path: Path or str, + validation_dataset: Optional[Dataset or str] = None, + criterion: Optional[str] = "voi", + iteration: Optional[int] = None, + parameters: Optional[PostProcessorParameters or str] = None, + roi: Optional[Roi or str] = None, + num_cpu_workers: int = 30, + output_dtype: Optional[np.dtype or str] = np.uint8, + compute_context: ComputeContext = LocalTorch(), + overwrite: bool = True, + file_format: str = "zarr", +): + """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + + if isinstance(roi, str): + start, end = zip( + *[ + tuple(int(coord) for coord in axis.split(":")) + for axis in roi.strip("[]").split(",") + ] + ) + roi = Roi( + Coordinate(start), + Coordinate(end) - Coordinate(start), + ) + + assert (validation_dataset is not None and isinstance(criterion, str)) or ( + isinstance(iteration, int) + ), "Either validation_dataset and criterion, or iteration must be provided." 
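For reference, a minimal sketch (not part of the patch) of the roi string format that the parsing above expects, one lower:upper pair per axis as also described in the CLI help further down; the concrete coordinates here are hypothetical:

    # hypothetical 3D example of the roi string accepted by apply()
    roi = "[0:100,20:120,40:140]"
    start, end = zip(
        *[tuple(int(coord) for coord in axis.split(":")) for axis in roi.strip("[]").split(",")]
    )
    # start == (0, 20, 40), end == (100, 120, 140)
    # -> Roi(offset=(0, 20, 40), shape=(100, 100, 100))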
+
+    # retrieving run
+    logger.info("Loading run %s", run_name)
+    config_store = create_config_store()
+    run_config = config_store.retrieve_run_config(run_name)
+    run = Run(run_config)
+
+    # create weights store
+    weights_store = create_weights_store()
+
+    # load weights
+    if iteration is None:
+        # weights_store._load_best(run, criterion)
+        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
+    logger.info("Loading weights for iteration %i", iteration)
+    weights_store.retrieve_weights(run, iteration)  # shouldn't this be load_weights?
+
+    # find the best parameters
+    if isinstance(validation_dataset, str):
+        val_ds_name = validation_dataset
+        validation_dataset = [
+            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
+        ][0]
+    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
+    if parameters is None:
+        parameters = run.task.evaluator.get_overall_best_parameters(
+            validation_dataset, criterion
+        )
+        assert (
+            parameters is not None
+        ), "Unable to retrieve parameters. Parameters must be provided explicitly."
+
+    elif isinstance(parameters, str):
+        try:
+            post_processor_name = parameters.split("(")[0]
+            post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
+            post_processor_kwargs = {
+                key.strip(): value.strip()
+                for key, value in [arg.split("=") for arg in post_processor_kwargs]
+            }
+            for key, value in post_processor_kwargs.items():
+                if value.isdigit():
+                    post_processor_kwargs[key] = int(value)
+                elif value.replace(".", "", 1).isdigit():
+                    post_processor_kwargs[key] = float(value)
+        except Exception:
+            raise ValueError(
+                f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
+            )
+        try:
+            parameters = getattr(post_processors, post_processor_name)(
+                **post_processor_kwargs
+            )
+        except Exception as e:
+            logger.error(
+                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
+                exc_info=True,
+            )
+            raise e
+
+    assert isinstance(
+        parameters, PostProcessorParameters
+    ), "Parameters must be parsable to a PostProcessorParameters object."
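For reference, a minimal sketch (not part of the patch, with a hypothetical post-processor name) of the parameters string format that the parsing above expects; the name must resolve to a PostProcessorParameters class in dacapo.experiments.tasks.post_processors:

    # hypothetical example of the 'post_processor_name(arg1=val1, arg2=val2, ...)' format
    parameters = "SomePostProcessorParameters(threshold=0.5, min_size=100)"
    # parsed as above into:
    #   post_processor_name == "SomePostProcessorParameters"
    #   post_processor_kwargs == {"threshold": 0.5, "min_size": 100}
    # then instantiated via getattr(post_processors, post_processor_name)(**post_processor_kwargs)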
+ + # make array identifiers for input, predictions and outputs + input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) + input_array = ZarrArray.open_from_array_identifier(input_array_identifier) + roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( + input_array.roi + ) + output_container = Path( + output_path, + "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}", + ) + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) + output_array_identifier = LocalArrayIdentifier( + output_container, f"output_{run_name}_{iteration}_{parameters}" + ) + logger.info( - "Applying results from run %s at iteration %d to dataset %s", - run_name, + "Applying best results from run %s at iteration %i to dataset %s", + run.name, iteration, - dataset_name, + Path(input_container, input_dataset), + ) + return apply_run( + run, + parameters, + input_array, + prediction_array_identifier, + output_array_identifier, + roi, + num_cpu_workers, + output_dtype, + compute_context, + overwrite, + ) + + +def apply_run( + run: Run, + parameters: PostProcessorParameters, + input_array: Array, + prediction_array_identifier: LocalArrayIdentifier, + output_array_identifier: LocalArrayIdentifier, + roi: Optional[Roi] = None, + num_cpu_workers: int = 30, + output_dtype: Optional[np.dtype] = np.uint8, + compute_context: ComputeContext = LocalTorch(), + overwrite: bool = True, +): + """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" + run.model.eval() + + # render prediction dataset + logger.info("Predicting on dataset %s", prediction_array_identifier) + predict( + run.model, + input_array, + prediction_array_identifier, + output_roi=roi, + num_cpu_workers=num_cpu_workers, + output_dtype=output_dtype, + compute_context=compute_context, + overwrite=overwrite, ) + + # post-process the output + logger.info("Post-processing output to dataset %s", output_array_identifier) + post_processor = run.task.post_processor + post_processor.set_prediction(prediction_array_identifier) + post_processor.process( + parameters, output_array_identifier, overwrite=overwrite, blockwise=True + ) + + logger.info("Done") + return diff --git a/dacapo/cli.py b/dacapo/cli.py index 76a5e18e0..f8f06db54 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -1,3 +1,5 @@ +from typing import Optional + import dacapo import click import logging @@ -40,21 +42,52 @@ def validate(run_name, iteration): @cli.command() @click.option( - "-r", "--run", required=True, type=str, help="The name of the run to use." + "-r", "--run_name", required=True, type=str, help="The name of the run to use." ) @click.option( - "-i", - "--iteration", + "-ic", + "--input_container", required=True, - type=int, - help="The iteration weights and parameters to use.", + type=click.Path(exists=True, file_okay=False), ) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) +@click.option("-vd", "--validation_dataset", type=str, default=None) +@click.option("-c", "--criterion", default="voi") +@click.option("-i", "--iteration", type=int, default=None) +@click.option("-p", "--parameters", type=str, default=None) @click.option( - "-r", - "--dataset", - required=True, + "-roi", + "--roi", type=str, - help="The name of the dataset to apply the run to.", + required=False, + help="The roi to predict on. 
Passed in as [lower:upper, lower:upper, ... ]", ) -def apply(run_name, iteration, dataset_name): - dacapo.apply(run_name, iteration, dataset_name) +@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +def apply( + run_name: str, + input_container: str, + input_dataset: str, + output_path: str, + validation_dataset: Optional[str] = None, + criterion: Optional[str] = "voi", + iteration: Optional[int] = None, + parameters: Optional[str] = None, + roi: Optional[str] = None, + num_cpu_workers: int = 30, + output_dtype: Optional[str] = "uint8", +): + dacapo.apply( + run_name, + input_container, + input_dataset, + output_path, + validation_dataset, + criterion, + iteration, + parameters, + roi, + num_cpu_workers, + output_dtype, + ) diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index c1014fd02..4a1b8cc4a 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -12,7 +12,11 @@ def __init__(self, task_config): """Create a `DummyTask` from a `DummyTaskConfig`.""" self.predictor = AffinitiesPredictor( - neighborhood=task_config.neighborhood, lsds=task_config.lsds + neighborhood=task_config.neighborhood, + lsds=task_config.lsds, + num_voxels=task_config.num_voxels, + downsample_lsds=task_config.downsample_lsds, + grow_boundary_iterations=task_config.grow_boundary_iterations, ) self.loss = AffinitiesLoss(len(task_config.neighborhood)) self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood) diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index d4b2c6199..0a94db79d 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -30,3 +30,23 @@ class AffinitiesTaskConfig(TaskConfig): "It has been shown that lsds as an auxiliary task can help affinity predictions." }, ) + num_voxels: int = attr.ib( + default=20, + metadata={ + "help_text": "The number of voxels to use for the gaussian sigma when computing lsds." + }, + ) + downsample_lsds: int = attr.ib( + default=1, + metadata={ + "help_text": "The amount to downsample the lsds. " + "This is useful for speeding up training and inference." + }, + ) + grow_boundary_iterations: int = attr.ib( + default=0, + metadata={ + "help_text": "The number of iterations to run the grow boundaries algorithm. " + "This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects." 
+ }, + ) diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 81efb2375..40d81f5da 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -17,9 +17,17 @@ class AffinitiesPredictor(Predictor): - def __init__(self, neighborhood: List[Coordinate], lsds: bool = True): + def __init__( + self, + neighborhood: List[Coordinate], + lsds: bool = True, + num_voxels: int = 20, + downsample_lsds: int = 1, + grow_boundary_iterations: int = 0, + ): self.neighborhood = neighborhood self.lsds = lsds + self.num_voxels = num_voxels if lsds: self._extractor = None if self.dims == 2: @@ -30,12 +38,16 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True): raise ValueError( f"Cannot compute lsds on volumes with {self.dims} dimensions" ) + self.downsample_lsds = downsample_lsds else: self.num_lsds = 0 + self.grow_boundary_iterations = grow_boundary_iterations def extractor(self, voxel_size): if self._extractor is None: - self._extractor = LsdExtractor(self.sigma(voxel_size)) + self._extractor = LsdExtractor( + self.sigma(voxel_size), downsample=self.downsample_lsds + ) return self._extractor @@ -45,8 +57,7 @@ def dims(self): def sigma(self, voxel_size): voxel_dist = max(voxel_size) # arbitrarily chosen - num_voxels = 10 # arbitrarily chosen - sigma = voxel_dist * num_voxels + sigma = voxel_dist * self.num_voxels # arbitrarily chosen return Coordinate((sigma,) * self.dims) def lsd_pad(self, voxel_size): @@ -118,7 +129,9 @@ def _grow_boundaries(self, mask, slab): slice(start[d], start[d] + slab[d]) for d in range(len(slab)) ) mask_slab = mask[slices] - dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1) + dilated_mask_slab = ndimage.binary_dilation( + mask_slab, iterations=self.grow_boundary_iterations + ) foreground[slices] = dilated_mask_slab # label new background @@ -130,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): (moving_class_counts, moving_lsd_class_counts) = ( moving_class_counts if moving_class_counts is not None else (None, None) ) - # mask_data = self._grow_boundaries( - # mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) - # ) - mask_data = mask[target.roi] + if self.grow_boundary_iterations > 0: + mask_data = self._grow_boundaries( + mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) + ) + else: + mask_data = mask[target.roi] aff_weights, moving_class_counts = balance_weights( target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8), 2, diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py index cd3fcd012..72c631ed4 100644 --- a/dacapo/experiments/training_stats.py +++ b/dacapo/experiments/training_stats.py @@ -16,7 +16,9 @@ class TrainingStats: def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None: if len(self.iteration_stats) > 0: - assert iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 + assert ( + iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 + ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}" self.iteration_stats.append(iteration_stats) diff --git a/dacapo/predict.py b/dacapo/predict.py index 5a40e303c..07483bea1 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -24,6 +24,8 @@ def predict( num_cpu_workers: int = 4, compute_context: 
ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, + output_dtype: Optional[np.dtype] = np.uint8, + overwrite: bool = False, ): # get the model's input and output size @@ -56,7 +58,8 @@ def predict( output_roi, model.num_out_channels, output_voxel_size, - np.float32, + output_dtype, + overwrite=overwrite, ) # create gunpowder keys @@ -68,6 +71,7 @@ def predict( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) + pipeline += gp.Normalize(raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) @@ -75,8 +79,8 @@ def predict( # raw: (1, c, d, h, w) gt_padding = (output_size - output_roi.shape) % output_size - prediction_roi = output_roi.grow(gt_padding) - + prediction_roi = output_roi.grow(gt_padding) # TODO: are we sure this makes sense? + # TODO: Add cache node? # predict pipeline += gp_torch.Predict( model=model, @@ -84,7 +88,9 @@ def predict( outputs={0: prediction}, array_specs={ prediction: gp.ArraySpec( - roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32 + roi=prediction_roi, + voxel_size=output_voxel_size, + dtype=np.float32, # assumes network output is float32 ) }, spawn_subprocess=False, @@ -97,22 +103,29 @@ def predict( pipeline += gp.Squeeze([raw, prediction]) # raw: (c, d, h, w) # prediction: (c, d, h, w) - # raw: (c, d, h, w) - # prediction: (c, d, h, w) + + # convert to uint8 if necessary: + if output_dtype == np.uint8: + pipeline += gp.IntensityScaleShift( + prediction, scale=255.0, shift=0.0 + ) # assumes float32 is [0,1] + pipeline += gp.AsType(prediction, output_dtype) # write to zarr pipeline += gp.ZarrWrite( {prediction: prediction_array_identifier.dataset}, prediction_array_identifier.container.parent, prediction_array_identifier.container.name, - dataset_dtypes={prediction: np.float32}, + dataset_dtypes={prediction: output_dtype}, ) # create reference batch request ref_request = gp.BatchRequest() ref_request.add(raw, input_size) ref_request.add(prediction, output_size) - pipeline += gp.Scan(ref_request) + pipeline += gp.Scan( + ref_request + ) # TODO: This is a slow implementation for rendering # build pipeline and predict in complete output ROI diff --git a/dacapo/train.py b/dacapo/train.py index e8667d8b8..1c104a55f 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dacapo.store.create_store import create_array_store from .experiments import Run from .compute_context import LocalTorch, ComputeContext @@ -10,6 +11,7 @@ import logging logger = logging.getLogger(__name__) +logger.setLevel("INFO") def train(run_name: str, compute_context: ComputeContext = LocalTorch()): @@ -100,8 +102,17 @@ def train_run( weights_store.retrieve_weights(run, iteration=latest_weights_iteration) logger.error( f"Found weights for iteration {latest_weights_iteration}, but " - f"run {run.name} was only trained until {trained_until}." + f"run {run.name} was only trained until {trained_until}. " + "Filling stats with last observed values." 
) + last_iteration_stats = run.training_stats.iteration_stats[-1] + for i in range( + last_iteration_stats.iteration, latest_weights_iteration - 1 + ): + new_iteration_stats = deepcopy(last_iteration_stats) + new_iteration_stats.iteration = i + 1 + run.training_stats.add_iteration_stats(new_iteration_stats) + trained_until = run.training_stats.trained_until() # start/resume training @@ -129,18 +140,20 @@ def train_run( # train for at most 100 iterations at a time, then store training stats iterations = min(100, run.train_until - trained_until) iteration_stats = None - - for iteration_stats in tqdm( + bar = tqdm( trainer.iterate( iterations, run.model, run.optimizer, compute_context.device, ), - "training", - iterations, - ): + desc=f"training until {iterations + trained_until}", + total=run.train_until, + initial=trained_until, + ) + for iteration_stats in bar: run.training_stats.add_iteration_stats(iteration_stats) + bar.set_postfix({"loss": iteration_stats.loss}) if (iteration_stats.iteration + 1) % run.validation_interval == 0: break @@ -162,22 +175,26 @@ def train_run( run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) - weights_store.store_weights(run, iteration_stats.iteration + 1) - validate_run( - run, - iteration_stats.iteration + 1, - compute_context=compute_context, - ) - stats_store.store_validation_iteration_scores( - run.name, run.validation_scores - ) stats_store.store_training_stats(run.name, run.training_stats) + weights_store.store_weights(run, iteration_stats.iteration + 1) + try: + validate_run( + run, + iteration_stats.iteration + 1, + compute_context=compute_context, + ) + stats_store.store_validation_iteration_scores( + run.name, run.validation_scores + ) + except Exception as e: + logger.error( + f"Validation failed for run {run.name} at iteration " + f"{iteration_stats.iteration + 1}.", + exc_info=e, + ) # make sure to move optimizer back to the correct device run.move_optimizer(compute_context.device) run.model.train() - weights_store.store_weights(run, run.training_stats.trained_until()) - stats_store.store_training_stats(run.name, run.training_stats) - logger.info("Trained until %d, finished.", trained_until) diff --git a/setup.py b/setup.py index b38a41edd..34faf365b 100644 --- a/setup.py +++ b/setup.py @@ -5,16 +5,16 @@ description="Framework for easy composition of volumetric machine learning jobs.", long_description=open("README.md", "r").read(), long_description_content_type="text/markdown", - version="0.1", + version="0.1.1", url="https://github.com/funkelab/dacapo", - author="Jan Funke, Will Patton", - author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org", + author="Jan Funke, Will Patton, Jeff Rhoades", + author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org, rhoadesj@hhmi.org", license="MIT", packages=find_packages(), entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]}, include_package_data=True, install_requires=[ - "numpy", + "numpy==1.22.3", "pyyaml", "zarr", "cattrs", From fe23b5d887d0ff325f5f54cc940a6a219f86be55 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:09:22 +0000 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20=F0=9F=9A=A7=20Incorporate=20simple?= =?UTF-8?q?=20change=20from=20rhoadesj/dev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/cli.py | 2 +- .../datasplits/datasets/arrays/zarr_array.py | 2 +- dacapo/predict.py | 3 +-- 
dacapo/train.py | 11 ----------- 4 files changed, 3 insertions(+), 15 deletions(-) diff --git a/dacapo/cli.py b/dacapo/cli.py index f8f06db54..f97906508 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -42,7 +42,7 @@ def validate(run_name, iteration): @cli.command() @click.option( - "-r", "--run_name", required=True, type=str, help="The name of the run to use." + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." ) @click.option( "-ic", diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 42030e701..25f2c224e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -52,7 +52,7 @@ def axes(self): logger.debug( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" - f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", + f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", ) return ["c", "z", "y", "x"][-self.dims : :] diff --git a/dacapo/predict.py b/dacapo/predict.py index 07483bea1..340517528 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -24,7 +24,7 @@ def predict( num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, - output_dtype: Optional[np.dtype] = np.uint8, + output_dtype: Optional[np.dtype] = np.float32, # add necessary type conversions overwrite: bool = False, ): # get the model's input and output size @@ -71,7 +71,6 @@ def predict( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) - pipeline += gp.Normalize(raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) diff --git a/dacapo/train.py b/dacapo/train.py index 1c104a55f..7beb096b4 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,4 +1,3 @@ -from copy import deepcopy from dacapo.store.create_store import create_array_store from .experiments import Run from .compute_context import LocalTorch, ComputeContext @@ -11,7 +10,6 @@ import logging logger = logging.getLogger(__name__) -logger.setLevel("INFO") def train(run_name: str, compute_context: ComputeContext = LocalTorch()): @@ -103,16 +101,7 @@ def train_run( logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " - "Filling stats with last observed values." 
) - last_iteration_stats = run.training_stats.iteration_stats[-1] - for i in range( - last_iteration_stats.iteration, latest_weights_iteration - 1 - ): - new_iteration_stats = deepcopy(last_iteration_stats) - new_iteration_stats.iteration = i + 1 - run.training_stats.add_iteration_stats(new_iteration_stats) - trained_until = run.training_stats.trained_until() # start/resume training From c3a81b789bc30e36d9b3d8db05e0e8d7f37b19a4 Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 8 Feb 2024 17:43:40 +0000 Subject: [PATCH 3/3] :art: Format Python code with psf/black --- dacapo/experiments/run.py | 21 ++++++++++++------- dacapo/experiments/tasks/hot_distance_task.py | 3 ++- .../tasks/hot_distance_task_config.py | 3 ++- .../tasks/losses/hot_distance_loss.py | 17 +++++++++------ dacapo/utils/balance_weights.py | 4 ++-- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 1609892c8..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__file__) + class Run: name: str train_until: int @@ -58,28 +59,34 @@ def __init__(self, run_config): return try: from ..store import create_config_store + start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) except Exception as e: - logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) raise e - + # preloaded weights from previous run if run_config.task_config.name == starter_config.task_config.name: self.start = Start(run_config.start_config) else: # Match labels between old and new head - if hasattr(run_config.task_config,"channels"): + if hasattr(run_config.task_config, "channels"): # Map old head and new head old_head = starter_config.task_config.channels new_head = run_config.task_config.channels - self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) else: logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config,remove_head=True) + self.start = Start(run_config.start_config, remove_head=True) self.start.initialize_weights(self.model) - @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py index 7f1e4dd96..ef0d03229 100644 --- a/dacapo/experiments/tasks/hot_distance_task.py +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -4,6 +4,7 @@ from .predictors import HotDistancePredictor from .task import Task + class HotDistanceTask(Task): """This is just a Hot Distance Task that combine Binary and distance prediction.""" @@ -21,4 +22,4 @@ def __init__(self, task_config): clip_distance=task_config.clip_distance, tol_distance=task_config.tol_distance, channels=task_config.channels, - ) \ No newline at end of file + ) diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index aab2b01d6..951226476 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -5,6 +5,7 @@ from typing 
import List + class HotDistanceTaskConfig(TaskConfig): """This is a Hot Distance task config used for generating and evaluating signed distance transforms as a way of generating @@ -43,4 +44,4 @@ class HotDistanceTaskConfig(TaskConfig): "object boundary cannot be known. This is anywhere that the distance to crop boundary " "is less than the distance to object boundary." }, - ) \ No newline at end of file + ) diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 77f34fd08..2e99ab5e1 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -1,6 +1,7 @@ from .loss import Loss import torch + # HotDistance is used for predicting hot and distance maps at the same time. # The first half of the channels are the hot maps, the second half are the distance maps. # The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. @@ -10,15 +11,19 @@ def compute(self, prediction, target, weight): target_hot, target_distance = self.split(target) prediction_hot, prediction_distance = self.split(prediction) weight_hot, weight_distance = self.split(weight) - return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance) - + return self.hot_loss( + prediction_hot, target_hot, weight_hot + ) + self.distance_loss(prediction_distance, target_distance, weight_distance) + def hot_loss(self, prediction, target, weight): return torch.nn.BCELoss().forward(prediction * weight, target * weight) - + def distance_loss(self, prediction, target, weight): return torch.nn.MSELoss().forward(prediction * weight, target * weight) - + def split(self, x): - assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + assert ( + x.shape[0] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." mid = x.shape[0] // 2 - return x[:mid], x[-mid:] \ No newline at end of file + return x[:mid], x[-mid:] diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 949bde0c4..5cd5ee597 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -75,11 +75,11 @@ def balance_weights( scale_slab *= np.take(w, labels_slab) if cross_class: - # get maximum error scale using first dimension + # get maximum error scale using first dimension shape = error_scale.shape error_scale = np.max(error_scale, axis=0) error_scale = np.broadcast_to(error_scale, shape) - + # set error_scale to 0 in masked-out areas for mask in masks: error_scale = error_scale * mask
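For reference, a minimal sketch (not part of the patch, channel count hypothetical) of the channel convention that HotDistanceLoss.split() assumes: the first half of the channels are the hot maps, the second half the distance maps, so the channel dimension must be even.

    import torch

    pred = torch.rand(6, 64, 64, 64)  # 3 hot channels followed by 3 distance channels
    mid = pred.shape[0] // 2
    hot, distance = pred[:mid], pred[-mid:]  # each (3, 64, 64, 64)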