Getting worse enhancement results from ConvTasNet, DCCRNet on DNS data #461
Unanswered
raikarsagar asked this question in Q&A
-
We'll need more detail than that.
-
```diff
diff --git a/asteroid/data/dns_dataset_wideband_norvb.py b/asteroid/data/dns_dataset_wideband_norvb.py
index 693b739..cb2b337 100644
--- a/asteroid/data/dns_dataset_wideband_norvb.py
+++ b/asteroid/data/dns_dataset_wideband_norvb.py
@@ -18,7 +18,7 @@ class DNSDataset2(data.Dataset):
dataset_name = "DNS"
- def __init__(self, json_dir, task, sample_rate, n_src, segment):
+ def __init__(self, json_dir, task, sample_rate=16000, n_src=1, segment=4):
super(DNSDataset2, self).__init__()
self.json_dir = json_dir
@@ -26,10 +26,29 @@ class DNSDataset2(data.Dataset):
self.sample_rate = sample_rate
self.n_src = n_src
self.segment = segment
+ self.seg_len = None if self.segment is None else int(self.segment * self.sample_rate)
+ self.like_test = self.seg_len is None
with open(os.path.join(json_dir, "file_infos.json"), "r") as f:
self.mix_infos = json.load(f)
+ # Filter out short utterances only when segment is specified
+ #orig_len = len(self.mix_infos)
+ #print(orig_len)
+ #drop_utt, drop_len = 0, 0
+ #if not self.like_test:
+ # print(len(self.mix_infos))
+ # for i in range(len(self.mix_infos) - 1, -1, -1): # Go backward
+ # print(type(self.mix_infos[i]["file_len"]))
+ # print(type(self.seg_len))
+ # if self.mix_infos[i]["file_len"] < self.seg_len:
+ # drop_utt += 1
+ # drop_len += self.mix_infos[i]["file_len"]
+ # del self.mix_infos[i]
+
+ # print("Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+ # drop_utt, drop_len/sample_rate/36000, orig_len, self.seg_len))
+
self.wav_ids = list(self.mix_infos.keys())
def __len__(self):
@@ -40,25 +59,40 @@ class DNSDataset2(data.Dataset):
Returns:
mixture, vstack([source_arrays])
"""
+ # Random start
+ #print(idx)
+ #print(self.mix_infos[self.wav_ids[idx]])
+ # print(self.seg_len)
+ if self.mix_infos[self.wav_ids[idx]]["file_len"] == self.seg_len or self.like_test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix_infos[self.wav_ids[idx]]["file_len"] - self.seg_len)
+ if self.like_test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # print(rand_start, stop)
utt_info = self.mix_infos[self.wav_ids[idx]]
+ # print(utt_info)
+
sources_list = []
# Load mixture
- x = torch.from_numpy(sf.read(utt_info["mix"], dtype="float32")[0])
+ x = torch.from_numpy(sf.read(utt_info["mix"], start=rand_start, stop=stop, dtype="float32")[0])
- s = sf.read(utt_info["clean"], dtype="float32")[0]
+ s = sf.read(utt_info["clean"], start=rand_start, stop=stop, dtype="float32")[0]
sources_list.append(s)
-
+ # print(sources_list)
# load clean without reverberation
speech = torch.from_numpy(np.vstack(sources_list))
# speech = sf.read(utt_info["clean_norvb"], dtype="float32")[0]
# Load clean
- speech_no_rvb = torch.from_numpy(sf.read(utt_info["clean_norvb"], dtype="float32")[0])
+ speech_no_rvb = torch.from_numpy(sf.read(utt_info["clean_norvb"], start=rand_start, stop=stop, dtype="float32")[0])
# Load noise
- noise = torch.from_numpy(sf.read(utt_info["noise"], dtype="float32")[0])
-
+ noise = torch.from_numpy(sf.read(utt_info["noise"], start=rand_start, stop=stop, dtype="float32")[0])
+ # print(x, speech, speech_no_rvb, noise)
return x, speech, speech_no_rvb, noise
def get_infos(self):
diff --git a/asteroid/engine/system_dns_wideband.py b/asteroid/engine/system_dns_wideband.py
index 3a2c6fd..80400ae 100644
--- a/asteroid/engine/system_dns_wideband.py
+++ b/asteroid/engine/system_dns_wideband.py
@@ -100,7 +100,9 @@ class System(pl.LightningModule):
the argument ``train`` can be used to switch behavior.
Otherwise, ``training_step`` and ``validation_step`` can be overwriten.
"""
+ # print(batch)
inputs, targets, _, _ = batch
+ # print(inputs, targets)
est_targets = self(inputs)
loss = self.loss_func(est_targets, targets)
return loss
diff --git a/egs/dns_challenge/ConvTasNet/eval_synthetic_3.py b/egs/dns_challenge/ConvTasNet/eval_synthetic_3.py
index 5b12d1d..1f55ac6 100644
--- a/egs/dns_challenge/ConvTasNet/eval_synthetic_3.py
+++ b/egs/dns_challenge/ConvTasNet/eval_synthetic_3.py
@@ -82,7 +82,7 @@ def get_wavs_dict_list(test_dir):
'id': 3}
"""
# Find all clean files and make an {id: filepath} dictionary
- clean_wavs = glob.glob(os.path.join(test_dir, "clean_norvb/*.wav"))
+ clean_wavs = glob.glob(os.path.join(test_dir, "clean/*.wav"))
clean_dic = make_wav_id_dict(clean_wavs)
# Same for noisy files
noisy_wavs = glob.glob(os.path.join(test_dir, "noisy/*.wav"))
diff --git a/egs/dns_challenge/ConvTasNet/local/conf.yml b/egs/dns_challenge/ConvTasNet/local/conf.yml
index 05b2a17..ef0d7ea 100644
--- a/egs/dns_challenge/ConvTasNet/local/conf.yml
+++ b/egs/dns_challenge/ConvTasNet/local/conf.yml
@@ -1,6 +1,6 @@
# filterbank config
filterbank:
- n_filters: 512
+ n_filters: 256
kernel_size: 16
stride: 8
# Network config
@@ -10,7 +10,7 @@ masknet:
mask_act: relu
bn_chan: 128
skip_chan: 128
- hid_chan: 512
+ hid_chan: 256
# Training config
training:
epochs: 200
diff --git a/egs/dns_challenge/ConvTasNet/local/conf_dns_wideband.yml b/egs/dns_challenge/ConvTasNet/local/conf_dns_wideband.yml
index 5c5fd65..f218dbb 100644
--- a/egs/dns_challenge/ConvTasNet/local/conf_dns_wideband.yml
+++ b/egs/dns_challenge/ConvTasNet/local/conf_dns_wideband.yml
@@ -1,16 +1,16 @@
# filterbank config
filterbank:
n_filters: 256
- kernel_size: 8
- stride: 4
+ kernel_size: 16
+ stride: 8
# Network config
masknet:
- n_blocks: 5
+ n_blocks: 8
n_repeats: 3
mask_act: relu
- bn_chan: 64
- skip_chan: 64
- hid_chan: 128
+ bn_chan: 128
+ skip_chan: 128
+ hid_chan: 512
# Training config
training:
epochs: 5
@@ -30,6 +30,6 @@ data:
# valid_dir: data/wav8k/min/dev
sample_rate: 16000
n_src: 1
- segment: 1
- json_dir: /home/stuart/sagar/interspeech21_dns/DNS-Challenge/datasets_wideband/datasets/training_set_feb22
+ segment: 4
+ json_dir: /home/ubuntu/data_se/training_set_feb22
val_prop: 0.1
diff --git a/egs/dns_challenge/ConvTasNet/run_wideband.sh b/egs/dns_challenge/ConvTasNet/run_wideband.sh
index 6e84ef2..eb9deab 100755
--- a/egs/dns_challenge/ConvTasNet/run_wideband.sh
+++ b/egs/dns_challenge/ConvTasNet/run_wideband.sh
@@ -30,8 +30,8 @@ n_blocks=5
n_repeats=3
mask_act=relu
# Training config
-epochs=2
-batch_size=1
+epochs=1
+batch_size=7
num_workers=1
half_lr=yes
early_stop=yes
@@ -43,10 +43,10 @@ weight_decay=0.
sample_rate=16000
mode=min
n_src=1
-segment=2
+segment=3
task=enh_both # one of 'enh_single', 'enh_both', 'sep_clean', 'sep_noisy'
-eval_use_gpu=1
+eval_use_gpu=0
# Need to --compute_wer 1 --eval_mode max to be sure the user knows all the metrics
# are for the all mode.
compute_wer=0
@@ -65,8 +65,9 @@ fi
# train_dir=data/$suffix/train-360
# valid_dir=data/$suffix/dev
# test_dir=data/wav${sr_string}k/$eval_mode/test
+test_dir=/home/ubuntu/data_se/test_set_feb22
-dumpdir=/home/stuart/sagar/interspeech21_dns/DNS-Challenge/datasets_wideband/datasets/training_set_feb22
+dumpdir=/home/ubuntu/data_se/training_set_feb22
if [[ $stage -le 0 ]]; then
echo "Stage 0: Generating Librimix dataset"
@@ -152,3 +153,7 @@ fi
# bash run_wideband.sh --tag dns_wideband_2 --id 0
# bash run_wideband.sh --tag dns_wideband_3 --id 0
+# segment length issue is resolved
+# bash run_wideband.sh --tag dns_wideband_4 --id 0
+
+# bash run_wideband.sh --tag dns_wideband_5 --id 0
diff --git a/egs/dns_challenge/ConvTasNet/run_wideband_resume.sh b/egs/dns_challenge/ConvTasNet/run_wideband_resume.sh
index 3343154..211b9e4 100755
--- a/egs/dns_challenge/ConvTasNet/run_wideband_resume.sh
+++ b/egs/dns_challenge/ConvTasNet/run_wideband_resume.sh
@@ -30,7 +30,7 @@ n_blocks=5
n_repeats=3
mask_act=relu
# Training config
-epochs=3
+epochs=5
batch_size=1
num_workers=1
half_lr=yes
@@ -44,7 +44,7 @@ sample_rate=16000
mode=min
n_src=1
segment=2
-task=sep_clean # one of 'enh_single', 'enh_both', 'sep_clean', 'sep_noisy'
+task=enh_both # one of 'enh_single', 'enh_both', 'sep_clean', 'sep_noisy'
eval_use_gpu=0
# Need to --compute_wer 1 --eval_mode max to be sure the user knows all the metrics
@@ -65,9 +65,9 @@ fi
# train_dir=data/$suffix/train-360
# valid_dir=data/$suffix/dev
# test_dir=data/wav${sr_string}k/$eval_mode/test
-test_dir=/home/stuart/sagar/interspeech21_dns/DNS-Challenge/datasets_wideband/datasets/test_set_feb22
+test_dir=/home/ubuntu/data_se/test_set_feb22
-dumpdir=/home/stuart/sagar/interspeech21_dns/DNS-Challenge/datasets_wideband/datasets/training_set_feb22
+dumpdir=/home/ubuntu/data_se/training_set_feb22
if [[ $stage -le 0 ]]; then
echo "Stage 0: Generating Librimix dataset"
diff --git a/egs/dns_challenge/ConvTasNet/train_wideband_resume.py b/egs/dns_challenge/ConvTasNet/train_wideband_resume.py
index d8acc11..708ac2e 100644
--- a/egs/dns_challenge/ConvTasNet/train_wideband_resume.py
+++ b/egs/dns_challenge/ConvTasNet/train_wideband_resume.py
@@ -37,8 +37,9 @@ def main(conf):
train_len = int(len(total_set) * (1 - conf["data"]["val_prop"]))
val_len = len(total_set) - train_len
+ print("train len and val len: ", train_len, val_len)
train_set, val_set = random_split(total_set, [train_len, val_len])
-
+ print("train set and val set: ", train_set, val_set)
# train_set = LibriMix(
# csv_dir=conf["data"]["train_dir"],
@@ -118,7 +119,7 @@ def main(conf):
max_epochs=conf["training"]["epochs"],
callbacks=callbacks,
default_root_dir=exp_dir,
- resume_from_checkpoint=os.path.join(checkpoint_dir,'epoch=2-step=86609.ckpt'),
+ resume_from_checkpoint=os.path.join(checkpoint_dir,'epoch=1-step=86399.ckpt'),
gpus=gpus,
distributed_backend=distributed_backend,
limit_train_batches=1.0, # Useful for fast experiment
@@ -129,8 +130,9 @@ def main(conf):
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
json.dump(best_k, f, indent=0)
-
- state_dict = torch.load(checkpoint.best_model_path)
+ # print("checkpoint.best_model_path: ", checkpoint.best_model_path)
+ # state_dict = torch.load(checkpoint.best_model_path)
+ state_dict = torch.load("/home/ubuntu/interspeech_21/dns-challenge/egs/dns_challenge/ConvTasNet/exp/train_convtasnet_dns_wideband_3/checkpoints/epoch=1-step=86399.ckpt")
system.load_state_dict(state_dict=state_dict["state_dict"])
        system.cpu()
```
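The core of the dataset change above is a random-start crop read directly with soundfile's start/stop arguments. Note that the diff also comments out the short-utterance filter, so a file shorter than seg_len would make np.random.randint fail (its upper bound goes negative), and variable-length reads would break fixed-size batching. Below is a minimal, self-contained sketch of the same cropping pattern with that guard restored; the file_infos.json layout is taken from the diff, while the class name and defaults are illustrative rather than asteroid's actual API.

```python
# Illustrative sketch of the random-crop loading pattern from the diff above.
# Assumes file_infos.json maps ids to {"mix": path, "clean": path, "file_len": n_samples}.
import json
import os

import numpy as np
import soundfile as sf
import torch
from torch.utils import data


class RandomCropSpeechSet(data.Dataset):
    """Random fixed-length crops; full files when segment is None (test-like mode)."""

    def __init__(self, json_dir, sample_rate=16000, segment=4):
        self.seg_len = None if segment is None else int(segment * sample_rate)
        with open(os.path.join(json_dir, "file_infos.json")) as f:
            self.infos = json.load(f)
        self.ids = list(self.infos.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        info = self.infos[self.ids[idx]]
        if self.seg_len is None or info["file_len"] <= self.seg_len:
            # Test-like mode, or file too short to crop: read the whole file.
            # (Short files should be filtered or padded before fixed-size batching.)
            start, stop = 0, None
        else:
            start = np.random.randint(0, info["file_len"] - self.seg_len)
            stop = start + self.seg_len
        # soundfile reads only the requested frame range from disk.
        mix = torch.from_numpy(sf.read(info["mix"], start=start, stop=stop, dtype="float32")[0])
        clean = torch.from_numpy(sf.read(info["clean"], start=start, stop=stop, dtype="float32")[0])
        return mix, clean
```

Wrapping this in a torch.utils.data.DataLoader with a fixed batch size only works if every file is at least segment seconds long, which is what the commented-out filtering in the diff was for.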
-
I have trained ConvTasNet and DCCRNet models using the LibriMix and WHAM recipes on DNS data with custom data loaders, but the eval results deviate substantially from the reported ones.
One such case:
Data: DNS wideband dataset
Model: ConvTasNet - WHAM recipe with WHAM dataloader
Overall metrics:
```
{'sar': -8.774382063462335,
 'sar_imp': -23.122993536855084,
 'sdr': -8.774382063462335,
 'sdr_imp': -23.122993536855084,
 'si_sdr': -13.292219625200545,
 'si_sdr_imp': -27.630510232650806,
 'sir': inf,
 'sir_imp': nan,
 'stoi': 0.5870689938616673,
 'stoi_imp': -0.27482733389575525}
```
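Every *_imp value above is strongly negative, i.e. the estimates score well below the raw noisy input, and sir is inf (hence sir_imp is nan) because there is only a single target source. For checking the eval pipeline end to end, asteroid's asteroid.metrics.get_metrics scores both the estimate and the mixture in one call; a minimal sketch, where the file paths and the stand-in estimate are illustrative assumptions:

```python
# Minimal sketch of how per-utterance numbers like the dict above are produced;
# file paths are illustrative and `est` is a stand-in for the model's output.
import soundfile as sf
from asteroid.metrics import get_metrics

mix, sr = sf.read("noisy/utt0.wav", dtype="float32")
clean, _ = sf.read("clean/utt0.wav", dtype="float32")
est = mix.copy()  # replace with the enhanced signal from the trained model

metrics = get_metrics(
    mix,
    clean[None],  # clean sources as (n_src, time)
    est[None],    # estimates as (n_src, time)
    sample_rate=sr,
    metrics_list=["si_sdr", "sdr", "sir", "sar", "stoi"],
)
# get_metrics also scores the raw mixture under input_* keys, which is how
# asteroid's eval scripts form the *_imp values, e.g.:
si_sdr_imp = metrics["si_sdr"] - metrics["input_si_sdr"]
```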