Skip to content

Commit

Permalink
balance weights:
Browse files Browse the repository at this point in the history
avoid using dacapo specific classes in utils. Utils is a place to put
code that is common and should maybe be provided by a seperate library
  • Loading branch information
pattonw committed Mar 10, 2022
1 parent 70e4b92 commit 8025372
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
18 changes: 11 additions & 7 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch

import logging
from typing import List

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,13 +77,16 @@ def create_weight(self, gt, target, mask):
)
else:
distance_mask = np.ones_like(target.data)
return balance_weights(
gt,
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[
NumpyArray.from_np_array(distance_mask, gt.roi, gt.voxel_size, gt.axes)
],
return NumpyArray.from_np_array(
balance_weights(
gt[target.roi],
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[distance_mask],
),
gt.roi,
gt.voxel_size,
gt.axes,
)

@property
Expand Down
40 changes: 18 additions & 22 deletions dacapo/utils/balance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,34 @@
import numpy as np

import itertools
from typing import List


def balance_weights(
labels,
num_classes,
masks=list(),
label_data: np.ndarray,
num_classes: int,
masks: List[np.ndarray] = list(),
slab=None,
clipmin=0.05,
clipmax=0.95,
clipmin: float = 0.05,
clipmax: float = 0.95,
):

label_data = labels[labels.roi]

assert len(np.unique(label_data)) <= num_classes, (
"Found more unique labels than classes in %s." % labels
)
assert 0 <= np.min(label_data) < num_classes, (
"Labels %s are not in [0, num_classes)." % labels
)
assert 0 <= np.max(label_data) < num_classes, (
"Labels %s are not in [0, num_classes)." % labels
)
unique_labels = np.unique(label_data)
assert (
len(unique_labels) <= num_classes
), f"Found unique labels {unique_labels} but expected only {num_classes}."
assert (
0 <= np.min(label_data) < num_classes
), f"Labels {unique_labels} are not in [0, {num_classes})."
assert (
0 <= np.max(label_data) < num_classes
), f"Labels {unique_labels} are not in [0, {num_classes})."

# initialize error scale with 1s
error_scale = np.ones(label_data.shape, dtype=np.float32)

# set error_scale to 0 in masked-out areas
for mask in masks:
error_scale *= mask[labels.roi]
error_scale *= mask

if slab is None:
slab = error_scale.shape
Expand Down Expand Up @@ -71,7 +70,4 @@ def balance_weights(
# scale_slab the masked-in scale_slab with the class weights
scale_slab *= np.take(w, labels_slab)

weights = NumpyArray.from_np_array(
error_scale, labels.roi, labels.voxel_size, labels.axes
)
return weights
return error_scale

0 comments on commit 8025372

Please sign in to comment.