Skip to content

Commit

Permalink
Canary <0.5s inference fix via padding
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko committed Dec 4, 2024
1 parent 8464f2e commit aabcb76
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'text_field': config.get('text_field', 'answer'),
'lang_field': config.get('lang_field', 'target_lang'),
'channel_selector': config.get('channel_selector', None),
'pad_min_duration': config.get('pad_min_duration', 1.0),
'pad_direction': config.get('pad_direction', 'both'),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
Expand Down
6 changes: 6 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ class LhotseDataLoadingConfig:
rir_enabled: bool = False
rir_path: str | None = None # str, must point to a lhotse RecordingSet manifest
rir_prob: float = 0.5
# f. Padding to a minimum duration. Examples shorter than this will be padded, others are unaffected.
pad_min_duration: Optional[float] = None
pad_direction: str = "right" # "right" | "left" | "both" | "random"

# 5. Other Lhotse options.
text_field: str = "text" # key to read the transcript from
Expand Down Expand Up @@ -257,6 +260,9 @@ def get_lhotse_dataloader_from_config(
keep_excessive_supervisions=config.keep_excessive_supervisions,
)

if config.pad_min_duration is not None:
cuts = cuts.pad(duration=config.pad_min_duration, direction=config.pad_direction, preserve_id=True)

# Duration filtering, same as native NeMo dataloaders.
# We can filter after the augmentations because they are applied only when calling load_audio().
cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration))
Expand Down
28 changes: 27 additions & 1 deletion tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, SupervisionSegment, compute_num_samples
from lhotse.audio import AudioLoadingError
from lhotse.cut import Cut, MixedCut
from lhotse.cut import Cut, MixedCut, PaddingCut
from lhotse.cut.text import TextPairExample
from lhotse.testing.dummies import dummy_recording
from omegaconf import OmegaConf
Expand Down Expand Up @@ -311,6 +311,32 @@ def test_dataloader_from_lhotse_cuts_cut_into_windows(cutset_path: Path):
# exactly 20 cuts were used because we cut 10x 1s cuts into 20x 0.5s cuts


def test_dataloader_from_lhotse_cuts_pad_min_duration(cutset_path: Path):
config = OmegaConf.create(
{
"cuts_path": cutset_path,
"pad_min_duration": 21.0,
"pad_direction": "left",
"sample_rate": 16000,
"shuffle": True,
"use_lhotse": True,
"num_workers": 0,
"batch_size": 1,
"seed": 0,
}
)

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

batch = next(iter(dl))
(cut,) = batch
assert cut.duration == 21.0
assert isinstance(cut, MixedCut)
assert len(cut.tracks) == 2
assert isinstance(cut.tracks[0].cut, PaddingCut)
assert isinstance(cut.tracks[1].cut, MonoCut)


def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path):
# Dataloader without channel selector
config = OmegaConf.create(
Expand Down

0 comments on commit aabcb76

Please sign in to comment.