Skip to content

Commit

Permalink
Finalize finetune_sam.py
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 18, 2024
1 parent 5295452 commit 342b61c
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 133 deletions.
31 changes: 17 additions & 14 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,10 @@ def automatic_instance_segmentation(
else:
image_data = util.load_image_data(input_path, key)

if ndim == 3 or image_data.ndim == 3:
if image_data.ndim != 3:
raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")
if ndim == 2:
assert image_data.ndim == 2 or image_data.shape[-1] == 3, \
f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}"

instances = automatic_3d_segmentation(
volume=image_data,
predictor=predictor,
segmentor=segmenter,
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)
else:
# Precompute the image embeddings.
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
Expand All @@ -137,6 +126,20 @@ def automatic_instance_segmentation(
instances = np.zeros(this_shape, dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
else:
if image_data.ndim != 3:
raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")

instances = automatic_3d_segmentation(
volume=image_data,
predictor=predictor,
segmentor=segmenter,
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)

if output_path is not None:
# Save the instance segmentation
Expand Down
5 changes: 5 additions & 0 deletions micro_sam/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def __call__(self, x, y):
#


def normalize_to_8bit(raw):
    """Scale the normalized input up to the 8-bit value range.

    Args:
        raw: The input image data.

    Returns:
        The input after normalization, scaled by 255.
    """
    # NOTE(review): 'normalize' presumably maps values to [0, 1] — confirm in its
    # definition; the result is returned as-is (no cast to uint8 despite the name).
    return normalize(raw) * 255


class ResizeRawTrafo:
def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
self.desired_shape = desired_shape
Expand Down
25 changes: 14 additions & 11 deletions workshops/download_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch_em.util.image import load_data


def _download_sample_data(path, data_dir, download, url, checksum):
def _download_sample_data(path, data_dir, url, checksum, download):
if os.path.exists(data_dir):
return

Expand All @@ -23,15 +23,15 @@ def _get_cellpose_sample_data_paths(path, download):
url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download"
checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a"

_download_sample_data(path, data_dir, download, url, checksum)
_download_sample_data(path, data_dir, url, checksum, download)

raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))

return raw_paths, label_paths


def _get_hpa_data_paths(path, download):
def _get_hpa_data_paths(path, split, download):
urls = [
"https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download", # train
"https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download", # val
Expand All @@ -43,23 +43,26 @@ def _get_hpa_data_paths(path, download):
"8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe", # test
]
splits = ["train", "val", "test"]
assert split in splits, f"'{split}' is not a valid split."

for url, checksum, split in zip(urls, checksums, splits):
data_dir = os.path.join(path, split)
_download_sample_data(path, data_dir, download, url, checksum)
for url, checksum, _split in zip(urls, checksums, splits):
data_dir = os.path.join(path, _split)
_download_sample_data(path, data_dir, url, checksum, download)

# NOTE: For visualization, we choose the train set.
raw_paths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.tif")))
label_paths = natsorted(glob(os.path.join(data_dir, "train", "labels", "*.tif")))
raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif")))

return raw_paths, label_paths
if split == "test": # The 'test' split for HPA does not have labels.
return raw_paths, None
else:
label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif")))
return raw_paths, label_paths


def _get_dataset_paths(path, dataset_name, view=False):
dataset_paths = {
# 2d LM dataset for cell segmentation
"cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True),
"hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True),
"hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"),
# 3d LM dataset for nuclei segmentation
"embedseg": lambda: datasets.embedseg_data.get_embedseg_paths(
path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True,
Expand Down
Loading

0 comments on commit 342b61c

Please sign in to comment.