diff --git a/i6_models/primitives/specaugment.py b/i6_models/primitives/specaugment.py index 79272eaf..b4d4d761 100644 --- a/i6_models/primitives/specaugment.py +++ b/i6_models/primitives/specaugment.py @@ -51,12 +51,12 @@ def _random_mask(tensor: torch.Tensor, batch_axis: int, axis: int, min_num: int, def zero_specaugment( audio_features: torch.Tensor, - time_min_num_masks, - time_max_num_masks, - time_mask_max_size, - freq_min_num_masks, - freq_max_num_masks, - freq_mask_max_size, + time_min_num_masks: int, + time_max_num_masks: int, + time_mask_max_size: int, + freq_min_num_masks: int, + freq_max_num_masks: int, + freq_mask_max_size: int, ): """ Specaugment from legacy rossenbach/zeineldeen/zeyer attention setups (usually called specaugment_v2.py or so), @@ -74,19 +74,23 @@ def zero_specaugment( :return: masked audio features """ assert len(tensor.shape) == 3 - tensor = _random_mask(audio_features, 0, 1, 2, time_num_masks, time_mask_max_size) # time masking - tensor = _random_mask(audio_features, 0, 2, 2, freq_num_masks, freq_mask_max_size) # freq masking + tensor = _random_mask( + audio_features, 0, 1, time_min_num_masks, time_max_num_masks, time_mask_max_size + ) # time masking + tensor = _random_mask( + audio_features, 0, 2, freq_min_num_masks, freq_max_num_masks, freq_mask_max_size + ) # freq masking return tensor def zero_specaugment_by_length( - audio_features, - time_mask_per_n_frames, - time_min_num_masks, - time_mask_max_size, - freq_min_num_masks, - freq_max_num_masks, - freq_mask_max_size, + audio_features: torch.Tensor, + time_mask_per_n_frames: int, + time_min_num_masks: int, + time_mask_max_size: int, + freq_min_num_masks: int, + freq_max_num_masks: int, + freq_mask_max_size: int, ): """ Convenience wrapper around zero_specaugment with time-length adaptive number of masks @@ -103,10 +107,10 @@ def zero_specaugment_by_length( :param freq_mask_max_size: maximum size of masks along F :return: masked audio features """ - return returnn_specaugment( + return zero_specaugment( audio_features, - time_num_masks=audio_features.size(1) // time_mask_per_n_frames, time_min_num_masks=time_min_num_masks, + time_max_num_masks=audio_features.size(1) // time_mask_per_n_frames, time_mask_max_size=time_mask_max_size, freq_min_num_masks=freq_min_num_masks, freq_max_num_masks=freq_max_num_masks,