class ProbablyIncorrectLanguageKeyError(RuntimeError):
    # Plain RuntimeError subclass; where it is raised is not visible in this view.
    pass


def _drop_in_memory_data(
    cuts: CutSet,
    _fields=frozenset(MixedCut.__dataclass_fields__.keys()),
) -> CutSet:
    """Workaround for an edge case in cuts.drop_in_memory_data() on MixedCut with Lhotse<1.29.0

    Used in place of ``cuts.drop_in_memory_data()``. Every cut is passed through
    its own ``drop_in_memory_data()``. For a ``MixedCut`` whose ``__dict__``
    holds keys that are not ``MixedCut`` dataclass fields, those extra
    attributes are removed for the duration of the call and then re-attached
    to both the input cut and the returned copy.

    Args:
        cuts: The cuts to process; iterated exactly once.
        _fields: Names of the ``MixedCut`` dataclass fields, computed once when
            the function is defined. The leading underscore marks it as an
            internal cache rather than an argument callers are expected to pass.

    Returns:
        A new ``CutSet`` built from the processed cuts, in the input order.

    Raises:
        Whatever ``drop_in_memory_data()`` raises. The input cut's extra
        attributes are put back before the exception propagates.
    """
    ans = []
    for c in cuts:
        # Not a mixed cut or a mixed cut that wasn't assigned any extra attributes.
        if not isinstance(c, MixedCut) or _fields.issuperset(c.__dict__.keys()):
            ans.append(c.drop_in_memory_data())
            continue

        extra_attrs = {k: v for k, v in c.__dict__.items() if k not in _fields}
        try:
            # Hide the extra attributes so drop_in_memory_data() only sees
            # the dataclass fields.
            for k in extra_attrs:
                delattr(c, k)
            stripped = c.drop_in_memory_data()
        finally:
            # Always restore the caller's object, even if the call above
            # failed; otherwise the input cut would stay mutated.
            for k, v in extra_attrs.items():
                setattr(c, k, v)

        # Carry the extra attributes over to the returned copy as well.
        for k, v in extra_attrs.items():
            setattr(stripped, k, v)
        ans.append(stripped)
    return CutSet(ans)