From 8025372642884dd4816ae772c6a8243bca4204fd Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 10 Mar 2022 16:46:29 -0500 Subject: [PATCH] balance weights: avoid using dacapo specific classes in utils. Utils is a place to put code that is common and should maybe be provided by a seperate library --- .../tasks/predictors/distance_predictor.py | 18 +++++---- dacapo/utils/balance_weights.py | 40 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index f18ea37f9..bbcf3c6cc 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -12,6 +12,7 @@ import torch import logging +from typing import List logger = logging.getLogger(__name__) @@ -76,13 +77,16 @@ def create_weight(self, gt, target, mask): ) else: distance_mask = np.ones_like(target.data) - return balance_weights( - gt, - 2, - slab=tuple(1 if c == "c" else -1 for c in gt.axes), - masks=[ - NumpyArray.from_np_array(distance_mask, gt.roi, gt.voxel_size, gt.axes) - ], + return NumpyArray.from_np_array( + balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[distance_mask], + ), + gt.roi, + gt.voxel_size, + gt.axes, ) @property diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 7ead93d5a..970abf42b 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -3,35 +3,34 @@ import numpy as np import itertools +from typing import List def balance_weights( - labels, - num_classes, - masks=list(), + label_data: np.ndarray, + num_classes: int, + masks: List[np.ndarray] = list(), slab=None, - clipmin=0.05, - clipmax=0.95, + clipmin: float = 0.05, + clipmax: float = 0.95, ): - - label_data = labels[labels.roi] - - assert len(np.unique(label_data)) <= num_classes, ( - "Found more unique labels than classes in %s." % labels - ) - assert 0 <= np.min(label_data) < num_classes, ( - "Labels %s are not in [0, num_classes)." % labels - ) - assert 0 <= np.max(label_data) < num_classes, ( - "Labels %s are not in [0, num_classes)." % labels - ) + unique_labels = np.unique(label_data) + assert ( + len(unique_labels) <= num_classes + ), f"Found unique labels {unique_labels} but expected only {num_classes}." + assert ( + 0 <= np.min(label_data) < num_classes + ), f"Labels {unique_labels} are not in [0, {num_classes})." + assert ( + 0 <= np.max(label_data) < num_classes + ), f"Labels {unique_labels} are not in [0, {num_classes})." # initialize error scale with 1s error_scale = np.ones(label_data.shape, dtype=np.float32) # set error_scale to 0 in masked-out areas for mask in masks: - error_scale *= mask[labels.roi] + error_scale *= mask if slab is None: slab = error_scale.shape @@ -71,7 +70,4 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) - weights = NumpyArray.from_np_array( - error_scale, labels.roi, labels.voxel_size, labels.axes - ) - return weights + return error_scale