Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/balancing #70

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
affs_weight_clipmin=task_config.affs_weight_clipmin,
affs_weight_clipmax=task_config.affs_weight_clipmax,
lsd_weight_clipmin=task_config.lsd_weight_clipmin,
lsd_weight_clipmax=task_config.lsd_weight_clipmax,
)
self.loss = AffinitiesLoss(
len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio
Expand Down
16 changes: 16 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,19 @@ class AffinitiesTaskConfig(TaskConfig):
"help_text": "If training with lsds, set how much they should be weighted compared to affs."
},
)
affs_weight_clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for affinities weights."},
)
affs_weight_clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for affinities weights."},
)
lsd_weight_clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for lsds weights."},
)
lsd_weight_clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for lsds weights."},
)
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def __init__(self, task_config):
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
clipmin=task_config.clipmin,
clipmax=task_config.clipmax,
)
self.loss = MSELoss()
self.post_processor = ThresholdPostProcessor()
Expand Down
8 changes: 8 additions & 0 deletions dacapo/experiments/tasks/distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ class DistanceTaskConfig(TaskConfig):
"is less than the distance to object boundary."
},
)
clipmin: float = attr.ib(
default=0.05,
metadata={"help_text": "The minimum value for distance weights."},
)
clipmax: float = attr.ib(
default=0.95,
metadata={"help_text": "The maximum value for distance weights."},
)
12 changes: 12 additions & 0 deletions dacapo/experiments/tasks/predictors/affinities_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(
num_voxels: int = 20,
downsample_lsds: int = 1,
grow_boundary_iterations: int = 0,
affs_weight_clipmin: float = 0.05,
affs_weight_clipmax: float = 0.95,
lsd_weight_clipmin: float = 0.05,
lsd_weight_clipmax: float = 0.95,
):
self.neighborhood = neighborhood
self.lsds = lsds
Expand All @@ -42,6 +46,10 @@ def __init__(
else:
self.num_lsds = 0
self.grow_boundary_iterations = grow_boundary_iterations
self.affs_weight_clipmin = affs_weight_clipmin
self.affs_weight_clipmax = affs_weight_clipmax
self.lsd_weight_clipmin = lsd_weight_clipmin
self.lsd_weight_clipmax = lsd_weight_clipmax

def extractor(self, voxel_size):
if self._extractor is None:
Expand Down Expand Up @@ -155,6 +163,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=tuple(1 if c == "c" else -1 for c in target.axes),
masks=[mask_data],
moving_counts=moving_class_counts,
clipmin=self.affs_weight_clipmin,
clipmax=self.affs_weight_clipmax,
)
if self.lsds:
lsd_weights, moving_lsd_class_counts = balance_weights(
Expand All @@ -163,6 +173,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=(-1,) * len(gt.axes),
masks=[mask_data],
moving_counts=moving_lsd_class_counts,
clipmin=self.lsd_weight_clipmin,
clipmax=self.lsd_weight_clipmax,
)
lsd_weights = np.ones(
(self.num_lsds,) + aff_weights.shape[1:], dtype=aff_weights.dtype
Expand Down
13 changes: 12 additions & 1 deletion dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ class DistancePredictor(Predictor):
in the channels argument.
"""

def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
def __init__(
self,
channels: List[str],
scale_factor: float,
mask_distances: bool,
clipmin: float = 0.05,
clipmax: float = 0.95,
):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
Expand All @@ -36,6 +43,8 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
self.max_distance = 1 * scale_factor
self.epsilon = 5e-2
self.threshold = 0.8
self.clipmin = clipmin
self.clipmax = clipmax

@property
def embedding_dims(self):
Expand Down Expand Up @@ -83,6 +92,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi], distance_mask],
moving_counts=moving_class_counts,
clipmin=self.clipmin,
clipmax=self.clipmax,
)
return (
NumpyArray.from_np_array(
Expand Down
43 changes: 12 additions & 31 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def __init__(self, trainer_config):
self.mask_integral_downsample_factor = 4
self.clip_raw = trainer_config.clip_raw

# Testing out if calculating multiple times and multiplying is necessary
self.add_predictor_nodes_to_dataset = (
trainer_config.add_predictor_nodes_to_dataset
)

self.scheduler = None

def create_optimizer(self, model):
Expand Down Expand Up @@ -85,8 +80,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER")

target_key = gp.ArrayKey("TARGET")
dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT")
datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT")
weight_key = gp.ArrayKey("WEIGHT")
sample_points_key = gp.GraphKey("SAMPLE_POINTS")

Expand Down Expand Up @@ -137,12 +130,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
+ gp.Pad(gt_key, None)
+ gp.Pad(mask_key, None)
+ gp.RandomLocation(
ensure_nonempty=sample_points_key
if points_source is not None
else None,
ensure_centered=sample_points_key
if points_source is not None
else None,
ensure_nonempty=(
sample_points_key if points_source is not None else None
),
ensure_centered=(
sample_points_key if points_source is not None else None
),
)
)

Expand All @@ -151,15 +144,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)

if self.add_predictor_nodes_to_dataset:
# Add predictor nodes to dataset_source
dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=dataset_weight_key,
mask_key=mask_key,
)

dataset_sources.append(dataset_source)
pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)

Expand All @@ -168,15 +152,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
task.predictor,
gt_key=gt_key,
target_key=target_key,
weights_key=datasets_weight_key
if self.add_predictor_nodes_to_dataset
else weight_key,
weights_key=weight_key,
mask_key=mask_key,
)

if self.add_predictor_nodes_to_dataset:
pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

# Trainer attributes:
if self.num_data_fetchers > 1:
pipeline += gp.PreCache(num_workers=self.num_data_fetchers)
Expand Down Expand Up @@ -332,9 +311,11 @@ def next(self):
NumpyArray.from_gp_array(batch[self._gt_key]),
NumpyArray.from_gp_array(batch[self._target_key]),
NumpyArray.from_gp_array(batch[self._weight_key]),
NumpyArray.from_gp_array(batch[self._mask_key])
if self._mask_key is not None
else None,
(
NumpyArray.from_gp_array(batch[self._mask_key])
if self._mask_key is not None
else None
),
)

def __enter__(self):
Expand Down
7 changes: 0 additions & 7 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,3 @@ class GunpowderTrainerConfig(TrainerConfig):
)
min_masked: Optional[float] = attr.ib(default=0.15)
clip_raw: bool = attr.ib(default=True)

add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
default=True,
metadata={
"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
},
)
Loading