Skip to content

Commit

Permalink
New sampler with ability to repeat data points
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneNx committed Feb 21, 2022
1 parent 1ff9dc9 commit 6585ed1
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions nntransfer/dataset/img_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import torchvision.transforms.functional as tF
import random
from typing import Sequence
from typing import Sequence, Iterator


class DiscreteRotateTransform:
Expand Down Expand Up @@ -269,7 +269,10 @@ def get_data_loaders(
subset_split = int(np.floor(config.train_subset * len(train_idx)))
train_idx = train_idx[:subset_split]
if config.shuffle:
train_sampler = SubsetRandomSampler(train_idx)
if config.data_repeats:
train_sampler = SubsetRandomSamplerRepeat(train_idx, repeats=config.data_repeats)
else:
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
else:
train_dataset = Subset(train_dataset, train_idx)
Expand Down Expand Up @@ -342,3 +345,23 @@ def get_data_loaders(
shuffle=True,
)
return data_loaders


class SubsetRandomSamplerRepeat(SubsetRandomSampler):
    r"""Sample the given indices in random order, repeating each one.

    A fresh random permutation of ``indices`` is drawn on every iteration
    (via ``torch.randperm``). Each selected index is then yielded
    ``repeats`` times in a row before moving on to the next one, so one
    pass yields ``len(indices) * repeats`` elements in total.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
        repeats (int): how many consecutive times each index is yielded.
            Defaults to 1, which yields every index exactly once.

    Raises:
        ValueError: if ``repeats`` is negative.
    """

    def __init__(
        self, indices: Sequence[int], generator=None, repeats: int = 1
    ) -> None:
        # Fail early: a negative count would make __len__ return a negative
        # number, which only surfaces later as an obscure error from len().
        if repeats < 0:
            raise ValueError(
                f"repeats must be a non-negative integer, got {repeats!r}"
            )
        super().__init__(indices, generator)
        self.repeats = repeats

    def __iter__(self) -> Iterator[int]:
        for i in torch.randperm(len(self.indices), generator=self.generator):
            # Look the index up once, then emit it `repeats` times back to back.
            index = self.indices[i]
            for _ in range(self.repeats):
                yield index

    def __len__(self) -> int:
        return len(self.indices) * self.repeats

0 comments on commit 6585ed1

Please sign in to comment.