Skip to content

Commit

Permalink
Revert GunpowderTrainer class and configuration to main
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 11, 2024
1 parent daa41b3 commit c810a0e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 40 deletions.
52 changes: 13 additions & 39 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def __init__(self, trainer_config):
self.print_profiling = 100
self.snapshot_iteration = trainer_config.snapshot_interval
self.min_masked = trainer_config.min_masked
self.reject_probability = trainer_config.reject_probability
self.weighted_reject = trainer_config.weighted_reject

self.augments = trainer_config.augments
self.mask_integral_downsample_factor = 4
Expand All @@ -53,14 +51,12 @@ def __init__(self, trainer_config):

def create_optimizer(self, model):
optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
self.scheduler = (
torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config
optimizer,
start_factor=0.01,
end_factor=1.0,
total_iters=1000,
last_epoch=-1,
)
self.scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=0.01,
end_factor=1.0,
total_iters=1000,
last_epoch=-1,
)
return optimizer

Expand All @@ -69,9 +65,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
output_shape = Coordinate(model.output_shape)

# get voxel sizes
raw_voxel_size = datasets[
0
].raw.voxel_size # TODO: make dataset specific / resample
raw_voxel_size = datasets[0].raw.voxel_size
prediction_voxel_size = model.scale(raw_voxel_size)

# define input and output size:
Expand All @@ -91,15 +85,15 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER")

target_key = gp.ArrayKey("TARGET")
dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") # TODO: put these back in
dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT")
datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT")
weight_key = gp.ArrayKey("WEIGHT")
sample_points_key = gp.GraphKey("SAMPLE_POINTS")

# Get source nodes
dataset_sources = []
weights = []
for dataset in datasets: # TODO: add automatic resampling?
for dataset in datasets:
weights.append(dataset.weight)
assert isinstance(dataset.weight, int), dataset

Expand Down Expand Up @@ -152,30 +146,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
)
)

if self.weighted_reject:
# Add predictor nodes to dataset_source
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)
dataset_source += gp.Reject(mask_placeholder, 1e-6)

dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=weight_key,
target_key=target_key,
mask_key=mask_key,
)

dataset_source += gp.Reject(
mask=weight_key,
min_masked=self.min_masked,
reject_probability=self.reject_probability,
)
else:
dataset_source += gp.Reject(
mask=mask_placeholder,
min_masked=self.min_masked,
reject_probability=self.reject_probability,
)
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)

if self.add_predictor_nodes_to_dataset:
# Add predictor nodes to dataset_source
Expand Down Expand Up @@ -290,7 +264,7 @@ def iterate(self, num_iterations, model, optimizer, device):
}
if mask is not None:
snapshot_arrays["volumes/mask"] = mask
logger.info(
logger.warning(
f"Saving Snapshot. Iteration: {iteration}, "
f"Loss: {loss.detach().cpu().numpy().item()}!"
)
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/trainers/gunpowder_trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class GunpowderTrainerConfig(TrainerConfig):
metadata={"help_text": "Number of iterations before saving a new snapshot."},
)
min_masked: Optional[float] = attr.ib(default=0.15)
clip_raw: bool = attr.ib(default=False)
clip_raw: bool = attr.ib(default=True)

add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
default=True,
Expand Down

0 comments on commit c810a0e

Please sign in to comment.