diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index ca52d742..c03c7a4f 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -59,7 +59,7 @@ def mask_data_to_segmentation( object in the output will be mapped to zero (the background value). min_object_size: The minimal size of an object in pixels. max_object_size: The maximal size of an object in pixels. - label_masks: Whether to apply connected components to the result before removing small objects. + label_masks: Whether to apply connected components to the result before remving small objects. Returns: The instance segmentation. @@ -85,8 +85,7 @@ def require_numpy(mask): seg_id = this_seg_id + 1 if label_masks: - segmentation = label(segmentation).astype(segmentation.dtype) - + segmentation = label(segmentation) seg_ids, sizes = np.unique(segmentation, return_counts=True) # In some cases objects may be smaller than peviously calculated, @@ -216,7 +215,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh): # recalculate boxes and remove any new duplicates masks = torch.cat(new_masks, dim=0) - boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels. + boxes = batched_mask_to_box(masks) keep_by_nms = batched_nms( boxes.float(), torch.as_tensor(scores, dtype=torch.float), @@ -1123,7 +1122,10 @@ def initialize( def get_amg( - predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs, + predictor: SamPredictor, + is_tiled: bool, + decoder: Optional[torch.nn.Module] = None, + **kwargs, ) -> Union[AMGBase, InstanceSegmentationWithDecoder]: """Get the automatic mask generator class. @@ -1137,10 +1139,9 @@ def get_amg( The automatic mask generator. """ if decoder is None: - segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator - segmenter = segmenter_class(predictor, **kwargs) + segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\ + AutomaticMaskGenerator(predictor, **kwargs) else: - segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder - segmenter = segmenter_class(predictor, decoder, **kwargs) - + segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\ + InstanceSegmentationWithDecoder(predictor, decoder, **kwargs) return segmenter