Skip to content

Commit

Permalink
Added online code-switching (CS) scheduler logic
Browse files (browse the repository at this point in the history)
  • Loading branch information
trias702 committed Aug 16, 2023
1 parent d1c4332 commit 614500d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def get_code_switched_dataset(
infinity_mode=cs_config.get('infinity_mode', False),
sample_rate=config['sample_rate'],
augmentor=augmentor,
schedule=cs_config.get('schedule', None),
)

return dataset
Expand Down
15 changes: 14 additions & 1 deletion nemo/collections/common/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __init__(
infinity_mode: bool = False,
sample_rate: int = 16000,
augmentor: Optional['AudioAugmentor'] = None,
schedule: Optional[List[int]] = None,
):
super().__init__()

Expand Down Expand Up @@ -373,6 +374,7 @@ def __init__(
self.augmentor = augmentor
self.sample_rate = sample_rate
self.length = 0
self.schedule = schedule
if lang_probs is None:
self.prob_dict = {l: 1.0 / len(self.langs) for l in self.langs}
else:
Expand Down Expand Up @@ -431,6 +433,8 @@ def __init__(
self.collate_fn = self.datasets[self.langs[0]].datasets[0].datasets[0].collate_fn
else:
raise RuntimeError("CodeSwitchedDataset could not locate a valid dataset collate_fn to bind to")

self.global_step = 0

# this method returns an iterator object for a given language ID
# it correctly handles whether the underlying dataset is IterableDataset or mappable
Expand All @@ -455,6 +459,14 @@ def build_single_CS_sample(self):
created_sample_duration_sec = 0
created_sample_langs = []
created_sample_audios = []

if self.schedule and any([x == self.global_step for x in self.schedule]):
idxs = [idx for idx, x in enumerate(self.schedule) if x == self.global_step]
for c in idxs:
self.langs_set.remove(c)
self.lang_probs[c] = 0
self.lang_probs /= self.lang_probs.sum()
self.prob_dict[c] = 0

# if min_monolingual fires, it means we will just return a single, original monolingual utterance
# from one of our languages based on that language's probability
Expand All @@ -466,7 +478,7 @@ def build_single_CS_sample(self):
# synthetic utterance, unless pure_random=True, in which case, you just sample with replacement
# every time
if (self.pure_random and not pure_mono) or (
len(set(created_sample_langs)) == 0 or len(set(created_sample_langs)) == len(self.langs)
len(set(created_sample_langs)) == 0 or len(set(created_sample_langs)) == len(self.langs_set)
):
lang_id = np.random.choice(self.langs, p=self.lang_probs)
# elif pure_mono:
Expand Down Expand Up @@ -597,6 +609,7 @@ def build_single_CS_sample(self):
self.augmentor.perturb(comp_audio_as)
comp_audio = comp_audio_as.samples

self.global_step += 1
return (
torch.tensor(comp_audio, dtype=audio.dtype, device=audio.device),
torch.tensor(len(comp_audio), device=audio_len.device).long(),
Expand Down

0 comments on commit 614500d

Please sign in to comment.