Skip to content

Commit

Permalink
Add Lhotse issue workaround
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko committed Dec 4, 2024
1 parent aabcb76 commit 4b1bfce
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch.utils.data
from lhotse import CutSet
from lhotse.cut import MixedCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

Expand Down Expand Up @@ -99,7 +100,7 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
prompt_lens=prompt_lens,
prompted_transcript=prompts_with_answers,
prompted_transcript_lens=prompts_with_answers_lens,
cuts=cuts.drop_in_memory_data(),
cuts=_drop_in_memory_data(cuts),
)

def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -111,3 +112,24 @@ def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple

class ProbablyIncorrectLanguageKeyError(RuntimeError):
    """Runtime error type dedicated to language-key problems.

    NOTE(review): judging by the name, this presumably signals that the key used to
    look up a cut's language is wrong -- confirm against the raise sites, which are
    not visible in this part of the file.
    """


def _drop_in_memory_data(
    cuts: CutSet,
    _fields: frozenset = frozenset(MixedCut.__dataclass_fields__.keys()),
) -> CutSet:
    """Workaround for an edge case in cuts.drop_in_memory_data() on MixedCut with Lhotse<1.29.0.

    Cuts that are not ``MixedCut`` instances, and ``MixedCut`` instances whose
    ``__dict__`` holds nothing beyond the declared dataclass fields, are handled with
    a plain ``drop_in_memory_data()`` call. A ``MixedCut`` that carries extra instance
    attributes has them removed for the duration of that call; afterwards they are
    re-attached both to the returned cut and to the original one.

    Args:
        cuts: The cuts to process.
        _fields: Names of the dataclass fields declared on ``MixedCut``. Evaluated once,
            when the function is defined; not meant to be supplied by callers.

    Returns:
        A new ``CutSet`` built from the per-cut ``drop_in_memory_data()`` results,
        in iteration order.
    """
    ans = []
    for c in cuts:
        # Not a mixed cut or a mixed cut that wasn't assigned any extra attributes.
        if not isinstance(c, MixedCut) or _fields.issuperset(c.__dict__.keys()):
            ans.append(c.drop_in_memory_data())
            continue

        # Temporarily strip the non-field attributes so that the call below only
        # sees the declared dataclass fields.
        # NOTE(review): presumably older Lhotse forwards every __dict__ entry when
        # re-creating the MixedCut -- confirm against the Lhotse changelog.
        extra_attrs = {k: v for k, v in c.__dict__.items() if k not in _fields}
        for k in extra_attrs:
            delattr(c, k)
        try:
            stripped = c.drop_in_memory_data()
        finally:
            # Put the attributes back on the caller's cut even when the call above
            # raises; otherwise the input would be left permanently stripped.
            for k, v in extra_attrs.items():
                setattr(c, k, v)
        # Carry the extra attributes over to the result as well.
        for k, v in extra_attrs.items():
            setattr(stripped, k, v)
        ans.append(stripped)
    return CutSet(ans)

0 comments on commit 4b1bfce

Please sign in to comment.