Skip to content

Commit

Permalink
Fixing the device assignment issues during inference (test_batch) in …
Browse files Browse the repository at this point in the history
…Sortformer model (#11671)

* Fixing the device assignment issues during inference (test_batch)

Signed-off-by: taejinp <[email protected]>

* Removing the commented code lines

Signed-off-by: taejinp <[email protected]>

---------

Signed-off-by: taejinp <[email protected]>
  • Loading branch information
tango4j authored Dec 20, 2024
1 parent 93f60fb commit 89c640c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
8 changes: 2 additions & 6 deletions nemo/collections/asr/data/audio_to_diar_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,15 +1237,11 @@ def __getitem__(self, index):
np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal
)
audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]

audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate).to(
self.device
)
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
targets = self.parse_rttm_for_targets_and_lens(
rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
).to(self.device)
)
return audio_signal, audio_signal_length, targets, target_len


Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/models/sortformer_diar_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def test_batch(
audio_signal, audio_signal_length, targets, target_lens = batch
audio_signal = audio_signal.to(self.device)
audio_signal_length = audio_signal_length.to(self.device)
targets = targets.to(self.device)
preds = self.forward(
audio_signal=audio_signal,
audio_signal_length=audio_signal_length,
Expand Down

0 comments on commit 89c640c

Please sign in to comment.