From b50b66eba5758de555d893c307f377e959d8ae90 Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Wed, 4 Dec 2024 22:59:31 +0800 Subject: [PATCH] Update i6_models/parts/best_rq/mask.py Co-authored-by: michelwi --- i6_models/parts/best_rq/mask.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py index 85346eb..d830e70 100644 --- a/i6_models/parts/best_rq/mask.py +++ b/i6_models/parts/best_rq/mask.py @@ -66,13 +66,8 @@ def forward( min_len = seq_len - num_mask - 1 mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False) - mask_idc = np.asarray( - [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(self.mask_length)] - ) - mask_idcs.append(mask_idc) - - for i, mask_idc in enumerate(mask_idcs): - mask[i, mask_idc] = True + for j in mask_idc: + mask[i, j : j+self.mask_length] = True tensor[mask] = self.mask_emb.to(tensor.device)