diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 72427951e..f5d8fcd52 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -37,8 +37,6 @@ def __init__(self, trainer_config): self.print_profiling = 100 self.snapshot_iteration = trainer_config.snapshot_interval self.min_masked = trainer_config.min_masked - self.reject_probability = trainer_config.reject_probability - self.weighted_reject = trainer_config.weighted_reject self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 @@ -53,14 +51,12 @@ def __init__(self, trainer_config): def create_optimizer(self, model): optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) - self.scheduler = ( - torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config - optimizer, - start_factor=0.01, - end_factor=1.0, - total_iters=1000, - last_epoch=-1, - ) + self.scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=1000, + last_epoch=-1, ) return optimizer @@ -69,9 +65,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): output_shape = Coordinate(model.output_shape) # get voxel sizes - raw_voxel_size = datasets[ - 0 - ].raw.voxel_size # TODO: make dataset specific / resample + raw_voxel_size = datasets[0].raw.voxel_size prediction_voxel_size = model.scale(raw_voxel_size) # define input and output size: @@ -91,7 +85,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") # TODO: put these back in + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -99,7 +93,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): # Get source nodes dataset_sources = [] weights = [] - for dataset in datasets: # TODO: add automatic resampling? + for dataset in datasets: weights.append(dataset.weight) assert isinstance(dataset.weight, int), dataset @@ -152,30 +146,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) ) - if self.weighted_reject: - # Add predictor nodes to dataset_source - for augment in self.augments: - dataset_source += augment.node(raw_key, gt_key, mask_key) + dataset_source += gp.Reject(mask_placeholder, 1e-6) - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=weight_key, - target_key=target_key, - mask_key=mask_key, - ) - - dataset_source += gp.Reject( - mask=weight_key, - min_masked=self.min_masked, - reject_probability=self.reject_probability, - ) - else: - dataset_source += gp.Reject( - mask=mask_placeholder, - min_masked=self.min_masked, - reject_probability=self.reject_probability, - ) + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) if self.add_predictor_nodes_to_dataset: # Add predictor nodes to dataset_source @@ -290,7 +264,7 @@ def iterate(self, num_iterations, model, optimizer, device): } if mask is not None: snapshot_arrays["volumes/mask"] = mask - logger.info( + logger.warning( f"Saving Snapshot. Iteration: {iteration}, " f"Loss: {loss.detach().cpu().numpy().item()}!" ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 255c73ad6..539e3c5e1 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -28,7 +28,7 @@ class GunpowderTrainerConfig(TrainerConfig): metadata={"help_text": "Number of iterations before saving a new snapshot."}, ) min_masked: Optional[float] = attr.ib(default=0.15) - clip_raw: bool = attr.ib(default=False) + clip_raw: bool = attr.ib(default=True) add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( default=True,