Skip to content

Commit

Permalink
Fixes for specaugment
Browse files (browse the repository at this point in the history)
  • Loading branch information
JackTemaki committed Sep 6, 2023
1 parent 1b1d653 commit a101f9e
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions i6_models/primitives/specaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def _random_mask(tensor: torch.Tensor, batch_axis: int, axis: int, min_num: int,

def zero_specaugment(
audio_features: torch.Tensor,
time_min_num_masks,
time_max_num_masks,
time_mask_max_size,
freq_min_num_masks,
freq_max_num_masks,
freq_mask_max_size,
time_min_num_masks: int,
time_max_num_masks: int,
time_mask_max_size: int,
freq_min_num_masks: int,
freq_max_num_masks: int,
freq_mask_max_size: int,
):
"""
Specaugment from legacy rossenbach/zeineldeen/zeyer attention setups (usually called specaugment_v2.py or so),
Expand All @@ -74,19 +74,23 @@ def zero_specaugment(
:return: masked audio features
"""
assert len(tensor.shape) == 3
tensor = _random_mask(audio_features, 0, 1, 2, time_num_masks, time_mask_max_size) # time masking
tensor = _random_mask(audio_features, 0, 2, 2, freq_num_masks, freq_mask_max_size) # freq masking
tensor = _random_mask(
audio_features, 0, 1, time_min_num_masks, time_max_num_masks, time_mask_max_size
) # time masking
tensor = _random_mask(
audio_features, 0, 2, freq_min_num_masks, freq_max_num_masks, freq_mask_max_size
) # freq masking
return tensor


def zero_specaugment_by_length(
audio_features,
time_mask_per_n_frames,
time_min_num_masks,
time_mask_max_size,
freq_min_num_masks,
freq_max_num_masks,
freq_mask_max_size,
audio_features: torch.Tensor,
time_mask_per_n_frames: int,
time_min_num_masks: int,
time_mask_max_size: int,
freq_min_num_masks: int,
freq_max_num_masks: int,
freq_mask_max_size: int,
):
"""
Convenience wrapper around zero_specaugment with time-length adaptive number of masks
Expand All @@ -103,10 +107,10 @@ def zero_specaugment_by_length(
:param freq_mask_max_size: maximum size of masks along F
:return: masked audio features
"""
return returnn_specaugment(
return zero_specaugment(
audio_features,
time_num_masks=audio_features.size(1) // time_mask_per_n_frames,
time_min_num_masks=time_min_num_masks,
time_max_num_masks=audio_features.size(1) // time_mask_per_n_frames,
time_mask_max_size=time_mask_max_size,
freq_min_num_masks=freq_min_num_masks,
freq_max_num_masks=freq_max_num_masks,
Expand Down

0 comments on commit a101f9e

Please sign in to comment.