Skip to content

Commit

Permalink
Remove merge conflict artifacts (computational-cell-analytics#774)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored and psobolewskiPhD committed Nov 29, 2024
1 parent 652900f commit 157d97f
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def mask_data_to_segmentation(
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
max_object_size: The maximal size of an object in pixels.
label_masks: Whether to apply connected components to the result before removing small objects.
label_masks: Whether to apply connected components to the result before remving small objects.
Returns:
The instance segmentation.
Expand All @@ -85,8 +85,7 @@ def require_numpy(mask):
seg_id = this_seg_id + 1

if label_masks:
segmentation = label(segmentation).astype(segmentation.dtype)

segmentation = label(segmentation)
seg_ids, sizes = np.unique(segmentation, return_counts=True)

# In some cases objects may be smaller than peviously calculated,
Expand Down Expand Up @@ -216,7 +215,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):

# recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels.
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores, dtype=torch.float),
Expand Down Expand Up @@ -1123,7 +1122,10 @@ def initialize(


def get_amg(
predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
predictor: SamPredictor,
is_tiled: bool,
decoder: Optional[torch.nn.Module] = None,
**kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
"""Get the automatic mask generator class.
Expand All @@ -1137,10 +1139,9 @@ def get_amg(
The automatic mask generator.
"""
if decoder is None:
segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
segmenter = segmenter_class(predictor, **kwargs)
segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\
AutomaticMaskGenerator(predictor, **kwargs)
else:
segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
segmenter = segmenter_class(predictor, decoder, **kwargs)

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\
InstanceSegmentationWithDecoder(predictor, decoder, **kwargs)
return segmenter

0 comments on commit 157d97f

Please sign in to comment.