From aabcb765b081b2295add829e253b17ad652141d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Tue, 3 Dec 2024 19:41:26 -0500
Subject: [PATCH] Canary <0.5s inference fix via padding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 .../asr/models/aed_multitask_models.py | 2 ++
 .../common/data/lhotse/dataloader.py | 6 ++++
 .../common/test_lhotse_dataloading.py | 28 ++++++++++++++++++-
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index a4a48a11171f..4399879deb48 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -972,6 +972,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
             'text_field': config.get('text_field', 'answer'),
             'lang_field': config.get('lang_field', 'target_lang'),
             'channel_selector': config.get('channel_selector', None),
+            'pad_min_duration': config.get('pad_min_duration', 1.0),
+            'pad_direction': config.get('pad_direction', 'both'),
         }
 
         temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index 98b63a07fa9d..dbcaa294074a 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -135,6 +135,9 @@ class LhotseDataLoadingConfig:
     rir_enabled: bool = False
     rir_path: str | None = None # str, must point to a lhotse RecordingSet manifest
     rir_prob: float = 0.5
+    # f. Padding to a minimum duration. Examples shorter than this will be padded, others are unaffected.
+    pad_min_duration: Optional[float] = None
+    pad_direction: str = "right" # "right" | "left" | "both" | "random"
 
     # 5. Other Lhotse options.
     text_field: str = "text" # key to read the transcript from
@@ -257,6 +260,9 @@ def get_lhotse_dataloader_from_config(
             keep_excessive_supervisions=config.keep_excessive_supervisions,
         )
 
+    if config.pad_min_duration is not None:
+        cuts = cuts.pad(duration=config.pad_min_duration, direction=config.pad_direction, preserve_id=True)
+
     # Duration filtering, same as native NeMo dataloaders.
     # We can filter after the augmentations because they are applied only when calling load_audio().
     cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration))
diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py
index ec682288cd4c..7322678d048a 100644
--- a/tests/collections/common/test_lhotse_dataloading.py
+++ b/tests/collections/common/test_lhotse_dataloading.py
@@ -23,7 +23,7 @@ import torch
 
 from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, SupervisionSegment, compute_num_samples
 from lhotse.audio import AudioLoadingError
-from lhotse.cut import Cut, MixedCut
+from lhotse.cut import Cut, MixedCut, PaddingCut
 from lhotse.cut.text import TextPairExample
 from lhotse.testing.dummies import dummy_recording
 from omegaconf import OmegaConf
@@ -311,6 +311,32 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
     # exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts
 
 
+def test_dataloader_from_lhotse_cuts_pad_min_duration(cutset_path: Path):
+    config = OmegaConf.create(
+        {
+            "cuts_path": cutset_path,
+            "pad_min_duration": 21.0,
+            "pad_direction": "left",
+            "sample_rate": 16000,
+            "shuffle": True,
+            "use_lhotse": True,
+            "num_workers": 0,
+            "batch_size": 1,
+            "seed": 0,
+        }
+    )
+
+    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())
+
+    batch = next(iter(dl))
+    (cut,) = batch
+    assert cut.duration == 21.0
+    assert isinstance(cut, MixedCut)
+    assert len(cut.tracks) == 2
+    assert isinstance(cut.tracks[0].cut, PaddingCut)
+    assert isinstance(cut.tracks[1].cut, MonoCut)
+
+
 def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
     # Dataloader without channel selector
     config = OmegaConf.create(