diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py
index 1dbe68589c0a..817938b758ae 100644
--- a/nemo/collections/asr/data/audio_to_diar_label.py
+++ b/nemo/collections/asr/data/audio_to_diar_label.py
@@ -1237,15 +1237,11 @@ def __getitem__(self, index):
             np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal
         )
         audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]
-        audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
-        audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
-        target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate).to(
-            self.device
-        )
+        target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
         targets = self.parse_rttm_for_targets_and_lens(
             rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
-        ).to(self.device)
+        )

         return audio_signal, audio_signal_length, targets, target_len
diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py
index 483ff5328ad0..e2ac0b09c81b 100644
--- a/nemo/collections/asr/models/sortformer_diar_models.py
+++ b/nemo/collections/asr/models/sortformer_diar_models.py
@@ -666,6 +666,7 @@ def test_batch(
                 audio_signal, audio_signal_length, targets, target_lens = batch
                 audio_signal = audio_signal.to(self.device)
                 audio_signal_length = audio_signal_length.to(self.device)
+                targets = targets.to(self.device)
                 preds = self.forward(
                     audio_signal=audio_signal,
                     audio_signal_length=audio_signal_length,
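
Taken together, the two hunks move device placement out of the dataset's `__getitem__` and into the model's batch loop: the dataset no longer calls `.to(self.device)` on its outputs, and `test_batch` now transfers `targets` to `self.device` alongside the audio tensors. This matches the usual PyTorch convention that `DataLoader` workers should emit CPU tensors, with the training/eval loop owning the host-to-device copy. The sketch below illustrates that convention outside NeMo; it is a minimal standalone example, and the `DiarExampleDataset` class, tensor shapes, and loader settings are hypothetical, not part of this PR.

```python
# Minimal sketch of the pattern applied in this PR: __getitem__ builds CPU tensors,
# and the evaluation loop performs the .to(device) transfer.
# All names and shapes here are illustrative, not NeMo's actual API.
import torch
from torch.utils.data import Dataset, DataLoader


class DiarExampleDataset(Dataset):
    """Toy diarization-style dataset that keeps every tensor on CPU inside __getitem__."""

    def __init__(self, num_items=8, num_samples=16000, num_frames=100, num_spks=4):
        self.num_items = num_items
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.num_spks = num_spks

    def __len__(self):
        return self.num_items

    def __getitem__(self, index):
        audio_signal = torch.randn(self.num_samples)  # CPU tensor, no .to(device) here
        audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
        targets = torch.randint(0, 2, (self.num_frames, self.num_spks)).float()
        target_len = torch.tensor(self.num_frames).long()
        return audio_signal, audio_signal_length, targets, target_len


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# num_workers > 0 and pin_memory=True are only safe because __getitem__ stays on CPU.
loader = DataLoader(DiarExampleDataset(), batch_size=2, num_workers=2, pin_memory=True)

for audio_signal, audio_signal_length, targets, target_lens in loader:
    # The loop, not the dataset, owns the host-to-device copy,
    # mirroring the targets.to(self.device) line added in test_batch.
    audio_signal = audio_signal.to(device)
    audio_signal_length = audio_signal_length.to(device)
    targets = targets.to(device)
    # ... forward pass would go here ...
```

Keeping `__getitem__` device-agnostic also avoids touching CUDA from `DataLoader` worker processes, which is the usual motivation for pushing `.to(device)` calls into the model-side batch handling.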