Merge pull request #266 from Visual-Behavior/alobugdays_issue107
Alobugdays issue107
jsalotti authored Nov 17, 2022
2 parents 1358b57 + 12c3a87 commit 121b3f2
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions alodataset/base_dataset.py
@@ -56,7 +56,7 @@ def stream_loader(dataset, num_workers=2):
     return data_loader
 
 
-def train_loader(dataset, batch_size=1, num_workers=2, sampler=torch.utils.data.RandomSampler):
+def train_loader(dataset, batch_size=1, num_workers=2, sampler=torch.utils.data.RandomSampler, sampler_kwargs={}):
     """Get training loader from the dataset
     Parameters
@@ -69,14 +69,15 @@ def train_loader(dataset, batch_size=1, num_workers=2, sampler=torch.utils.data.
         Number of workers, by default 2
     sampler : torch.utils.data, optional
         Callback to sampler the dataset, by default torch.utils.data.RandomSampler
+        Or instance of any class inheriting from torch.utils.data.Sampler
     Returns
     -------
     torch.utils.data.DataLoader
         A generator
     """
-    sampler = sampler(dataset) if sampler is not None else None
-
+    if sampler is not None and not(isinstance(sampler, torch.utils.data.Sampler)):
+        sampler = sampler(dataset, **sampler_kwargs)
     data_loader = torch.utils.data.DataLoader(
         dataset,
         # batch_sampler=batch_sampler,
@@ -332,7 +333,7 @@ def stream_loader(self, num_workers=2):
         """
         return stream_loader(self, num_workers=num_workers)
 
-    def train_loader(self, batch_size=1, num_workers=2, sampler=torch.utils.data.RandomSampler):
+    def train_loader(self, batch_size=1, num_workers=2, sampler=torch.utils.data.RandomSampler, sampler_kwargs={}):
         """Get training loader from the dataset
         Parameters
@@ -351,7 +352,7 @@ def train_loader(self, batch_size=1, num_workers=2, sampler=torch.utils.data.Ran
         torch.utils.data.DataLoader
             A generator
         """
-        return train_loader(self, batch_size=batch_size, num_workers=num_workers, sampler=sampler)
+        return train_loader(self, batch_size=batch_size, num_workers=num_workers, sampler=sampler, sampler_kwargs=sampler_kwargs )
 
     def prepare(self):
         """Prepare the dataset. Not all child class need to implement this method.
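For readers skimming the diff: the patched train_loader now accepts a sampler either as a class, instantiated internally as sampler(dataset, **sampler_kwargs), or as a ready-made torch.utils.data.Sampler instance, which the new isinstance check passes through untouched. The sketch below is illustrative only, not code from the repository: it reproduces just the sampler-handling branch from this commit, simplifies the DataLoader construction, and uses a throwaway TensorDataset in place of a real alodataset dataset.

import torch


def train_loader(dataset, batch_size=1, num_workers=2,
                 sampler=torch.utils.data.RandomSampler, sampler_kwargs={}):
    # Only instantiate `sampler` when it is not already a Sampler instance,
    # forwarding the new `sampler_kwargs` to its constructor (mirrors the diff above).
    if sampler is not None and not isinstance(sampler, torch.utils.data.Sampler):
        sampler = sampler(dataset, **sampler_kwargs)
    # Simplified stand-in for the real DataLoader call in base_dataset.py.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       num_workers=num_workers, sampler=sampler)


dataset = torch.utils.data.TensorDataset(torch.arange(10, dtype=torch.float32))

# 1) Sampler class plus constructor kwargs, instantiated as sampler(dataset, **sampler_kwargs).
loader = train_loader(dataset, batch_size=2,
                      sampler=torch.utils.data.RandomSampler,
                      sampler_kwargs={"replacement": True, "num_samples": 20})

# 2) Already-built Sampler instance, passed through unchanged by the isinstance check.
weights = torch.ones(10)
loader = train_loader(dataset, batch_size=2,
                      sampler=torch.utils.data.WeightedRandomSampler(weights, num_samples=10))

The mutable default sampler_kwargs={} is kept in the sketch only to mirror the commit; sampler_kwargs=None with an `or {}` fallback is the more defensive Python idiom.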
