-
Notifications
You must be signed in to change notification settings - Fork 51
/
dataset.py
120 lines (90 loc) · 4.54 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import numpy as np
from scipy.io.wavfile import read as wavread
import warnings
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
from torchvision import datasets, models, transforms
import torchaudio
class CleanNoisyPairDataset(Dataset):
    """
    Dataset of paired clean and noisy audio files.

    Each element is a tuple ``(clean waveform, noisy waveform, fileid)`` where
    ``fileid`` is the ``(clean_path, noisy_path)`` tuple the two waveforms were
    loaded from. Both waveforms are returned with a leading dimension of size
    one (added via ``unsqueeze(0)``).
    """

    def __init__(self, root='./', subset='training', crop_length_sec=0):
        """
        Collect the (clean, noisy) file path pairs for the requested subset.

        Args:
            root: base directory. The training subset is read from
                ``<root>/training_set/{clean,noisy}`` and the testing subset
                from ``<root>/datasets/test_set/synthetic/no_reverb/{clean,noisy}``.
            subset: ``"training"`` or ``"testing"``.
            crop_length_sec: length, in seconds, of the random crop applied to
                training items; ``0`` disables cropping. It is forced to ``0``
                for the testing subset.

        Raises:
            AssertionError: if ``subset`` is an unknown name, if the training
                clean/noisy directories hold a different number of entries, or
                if a sorted testing clean/noisy pair does not share the same
                file-id suffix.
            NotImplementedError: if ``subset`` is ``None``.
        """
        # Zero-argument super(): the one-argument form creates an unbound
        # super object and never runs the parent initialiser.
        super().__init__()

        assert subset is None or subset in ["training", "testing"]
        self.crop_length_sec = crop_length_sec
        self.subset = subset

        if subset == "training":
            # Only touch the training directories when they are actually
            # needed, so a test-only data layout can still be loaded.
            N_clean = len(os.listdir(os.path.join(root, 'training_set/clean')))
            N_noisy = len(os.listdir(os.path.join(root, 'training_set/noisy')))
            assert N_clean == N_noisy
            # Files are addressed by index, i.e. they are expected to be named
            # fileid_0.wav ... fileid_{N-1}.wav in both directories.
            self.files = [
                (os.path.join(root, 'training_set/clean', 'fileid_{}.wav'.format(i)),
                 os.path.join(root, 'training_set/noisy', 'fileid_{}.wav'.format(i)))
                for i in range(N_clean)
            ]

        elif subset == "testing":
            # Specific to DNS test sample names: clean and noisy files are
            # matched on the last two '_'-separated tokens of the file name.
            sortkey = lambda name: '_'.join(name.split('_')[-2:])
            _p = os.path.join(root, 'datasets/test_set/synthetic/no_reverb')  # path for DNS

            clean_files = sorted(os.listdir(os.path.join(_p, 'clean')), key=sortkey)
            noisy_files = sorted(os.listdir(os.path.join(_p, 'noisy')), key=sortkey)

            self.files = []
            for _c, _n in zip(clean_files, noisy_files):
                assert sortkey(_c) == sortkey(_n)
                self.files.append((os.path.join(_p, 'clean', _c),
                                   os.path.join(_p, 'noisy', _n)))
            # Test items are always returned uncropped.
            self.crop_length_sec = 0

        else:
            raise NotImplementedError

    def __getitem__(self, n):
        """
        Load the n-th pair and, for non-testing subsets, take an aligned random crop.

        Returns:
            ``(clean_audio, noisy_audio, fileid)`` where ``fileid`` is the
            ``(clean_path, noisy_path)`` tuple for this item.

        Raises:
            AssertionError: if the two signals differ in length or the
                requested crop is longer than the signal.
        """
        fileid = self.files[n]
        clean_audio, _ = torchaudio.load(fileid[0])
        # The sample rate of the noisy file is the one used to size the crop.
        noisy_audio, sample_rate = torchaudio.load(fileid[1])
        # NOTE(review): squeeze(0) only removes the first dimension when it has
        # size one - presumably the files are mono; confirm against the data.
        clean_audio, noisy_audio = clean_audio.squeeze(0), noisy_audio.squeeze(0)
        assert len(clean_audio) == len(noisy_audio)

        crop_length = int(self.crop_length_sec * sample_rate)
        # A clip exactly as long as the crop is valid (the only start is 0).
        assert crop_length <= len(clean_audio)

        # random crop
        if self.subset != 'testing' and crop_length > 0:
            # Draw the offset from torch's RNG rather than numpy's: DataLoader
            # seeds torch separately in every worker process, whereas workers
            # can inherit identical copies of the numpy RNG state and would
            # then produce duplicated crop offsets.
            last_start = len(clean_audio) - crop_length
            start = int(torch.randint(low=0, high=last_start + 1, size=(1,)).item())
            clean_audio = clean_audio[start:(start + crop_length)]
            noisy_audio = noisy_audio[start:(start + crop_length)]

        clean_audio, noisy_audio = clean_audio.unsqueeze(0), noisy_audio.unsqueeze(0)
        return (clean_audio, noisy_audio, fileid)

    def __len__(self):
        """Return the number of (clean, noisy) file pairs."""
        return len(self.files)
def load_CleanNoisyPairDataset(root, subset, crop_length_sec, batch_size, sample_rate, num_gpus=1):
    """
    Build a DataLoader over CleanNoisyPairDataset.

    With more than one GPU the loader draws indices from a DistributedSampler;
    otherwise it is a plain shuffling loader. ``sample_rate`` is accepted so
    the function can be called with a full config dict, but it is not used here.
    """
    pair_dataset = CleanNoisyPairDataset(
        root=root,
        subset=subset,
        crop_length_sec=crop_length_sec,
    )
    loader_options = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=False,
        drop_last=False,
    )

    if num_gpus > 1:
        # Multi-GPU: the sampler decides which indices this process sees.
        return torch.utils.data.DataLoader(
            pair_dataset, sampler=DistributedSampler(pair_dataset), **loader_options
        )

    # Single process: no sampler, let the loader shuffle.
    return torch.utils.data.DataLoader(
        pair_dataset, sampler=None, shuffle=True, **loader_options
    )
if __name__ == '__main__':
    # Smoke test: build the training and testing loaders from the DNS config,
    # print their lengths, and push one training batch to the GPU.
    import json

    with open('./configs/DNS-large-full.json') as f:
        config = json.loads(f.read())

    trainset_config = config["trainset_config"]

    trainloader = load_CleanNoisyPairDataset(
        **trainset_config, subset='training', batch_size=2, num_gpus=1
    )
    testloader = load_CleanNoisyPairDataset(
        **trainset_config, subset='testing', batch_size=2, num_gpus=1
    )
    print(len(trainloader), len(testloader))

    # Only the first batch is inspected.
    for batch in trainloader:
        clean_audio, noisy_audio, fileid = batch
        clean_audio, noisy_audio = clean_audio.cuda(), noisy_audio.cuda()
        print(clean_audio.shape, noisy_audio.shape, fileid)
        break