Expose is_seg_dataset argument in sam dataset (#736)
anwai98 authored Oct 15, 2024
1 parent 766aa9b commit cd4418a
Showing 1 changed file with 6 additions and 2 deletions: micro_sam/training/training.py
@@ -389,6 +389,7 @@ def default_sam_dataset(
     is_train: bool = True,
     min_size: int = 25,
     max_sampling_attempts: Optional[int] = None,
+    is_seg_dataset: Optional[bool] = None,
     **kwargs,
 ) -> Dataset:
     """Create a PyTorch Dataset for training a SAM model.
@@ -412,6 +413,8 @@
         is_train: Whether this dataset is used for training or validation.
         min_size: Minimal object size. Smaller objects will be filtered.
         max_sampling_attempts: Number of sampling attempts to make from a dataset.
+        is_seg_dataset: Whether the dataset should be created as a torch_em.data.SegmentationDataset
+            or as a torch_em.data.ImageCollectionDataset.
 
     Returns:
         The dataset.
@@ -443,8 +446,8 @@
     # Set a minimum number of samples per epoch.
     if n_samples is None:
         loader = torch_em.default_segmentation_loader(
-            raw_paths, raw_key, label_paths, label_key,
-            batch_size=1, patch_shape=patch_shape, ndim=2
+            raw_paths, raw_key, label_paths, label_key, batch_size=1,
+            patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
         )
         n_samples = max(len(loader), 100 if is_train else 5)

@@ -454,6 +457,7 @@
         raw_transform=raw_transform, label_transform=label_transform,
         with_channels=with_channels, ndim=2,
         sampler=sampler, n_samples=n_samples,
+        is_seg_dataset=is_seg_dataset,
         **kwargs,
     )

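For context, a minimal usage sketch of the newly exposed argument (not part of the commit). The folder paths, file patterns, and patch shape below are placeholder assumptions, and depending on the micro_sam version additional arguments such as with_segmentation_decoder may be required:

# Minimal sketch, assuming TIFF images and labels stored in separate folders.
from micro_sam.training.training import default_sam_dataset

train_dataset = default_sam_dataset(
    raw_paths="data/images",            # hypothetical folder with raw training images
    raw_key="*.tif",                    # hypothetical glob pattern for the raw images
    label_paths="data/labels",          # hypothetical folder with instance label masks
    label_key="*.tif",
    patch_shape=(512, 512),             # assumed 2D patch shape
    with_segmentation_decoder=False,    # assumed required by current micro_sam versions
    is_train=True,
    is_seg_dataset=True,                # new argument: force the SegmentationDataset path
)

Since the new argument defaults to None, existing calls to default_sam_dataset keep their previous behavior.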